mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 14:28:59 +00:00
Initial commit
This commit is contained in:
89
python/cutlass_api/README.md
Normal file
89
python/cutlass_api/README.md
Normal file
@@ -0,0 +1,89 @@
|
||||
# CUTLASS API
|
||||
|
||||
**NOTE: This is an experimental/early-access version of the CUTLASS API. All APIs, names, and paths are subject to change.**
|
||||
|
||||
The CUTLASS API provides high-level, universal interfaces to discover, compile,
|
||||
and execute GEMMs (including grouped GEMMs and scaled GEMMs) targeting NVIDIA GPUs.
|
||||
It allows GEMM kernels written in different DSLs to be integrated and discovered under
|
||||
a uniform API.
|
||||
|
||||
```python
|
||||
import cutlass_api
|
||||
import torch
|
||||
|
||||
A, B, out = (torch.randn(128, 128, device="cuda", dtype=torch.float16) for _ in range(3))
|
||||
|
||||
# Create arguments for the GEMM operation
|
||||
args = cutlass_api.arguments.GemmArguments(A, B, out, accumulator_type=torch.float32)
|
||||
|
||||
# Query for kernels that support our provided arguments on SM100
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
# JIT compile and execute one of the returned kernels
|
||||
kernels[0].run(args)
|
||||
```
|
||||
|
||||
## Deep dive
|
||||
|
||||
To learn more about the API, follow the in-depth tutorials in the [examples directory](./examples).
|
||||
|
||||
## Directory structure
|
||||
* [cutlass_api](./cutlass_api/): source for the CUTLASS API
|
||||
* [examples](./examples/): examples of using the CUTLASS API
|
||||
* [test](./test/): tests of the CUTLASS API
|
||||
|
||||
## Installation
|
||||
There is currently no wheel for the CUTLASS API.
|
||||
Install in editable mode from the `python/cutlass_api` directory of the CUTLASS repository:
|
||||
```bash
|
||||
# From the root of this README file
|
||||
|
||||
# Install required dependencies
|
||||
pip install -e .
|
||||
|
||||
# Install required dependencies for use with torch
|
||||
pip install -e ".[torch]"
|
||||
|
||||
# To install all dependencies to develop, run tests, etc.
|
||||
pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
## Running examples and tests
|
||||
[Tests](./test/) use [pytest](https://docs.pytest.org/en/stable/). An example of running a test is:
|
||||
```bash
|
||||
pytest test/integration/test_gemm.py
|
||||
```
|
||||
|
||||
[Examples](./examples/) are [Jupyter notebooks](https://jupyter.org/). They can be launched via:
|
||||
```bash
|
||||
cd examples
|
||||
jupyter notebook
|
||||
```
|
||||
|
||||
## Requirements, compatibility, and current support
|
||||
Please see [pyproject.toml](pyproject.toml) for dependencies.
|
||||
|
||||
**Compatibility:** CUTLASS API has the same compatibility requirements as the CUTLASS project.
|
||||
See CUTLASS's [Compatibility section](https://github.com/NVIDIA/cutlass?tab=readme-ov-file#compatibility).
|
||||
|
||||
### Current support
|
||||
|
||||
* Dense GEMM: `out = A @ B`
|
||||
- Compute capabilities: 100, 103
|
||||
- Input precisions (A and B must be of same type): F16, BF16, TF32, INT8
|
||||
- Output precisions: F32, F16, BF16, INT32
|
||||
- Epilogue operations:
|
||||
- Auxiliary load of tensors equal in rank and shape of `out`
|
||||
- Auxiliary store of tensors equal in rank and shape of `out`
|
||||
- Auxiliary load of scalar
|
||||
- Tensor-tensor elementwise or tensor-scalar addition, multiplication, subtraction, division
|
||||
- Elementwise tensor exponent, relu, sigmoid, tanh
|
||||
|
||||
* Planned additions
|
||||
* Block-scaled GEMMs
|
||||
* Grouped GEMMs
|
||||
* Additional epilogue operations
|
||||
* Reductions
|
||||
* Row/column broadcasts
|
||||
* SM90 kernels
|
||||
* Convolutions
|
||||
71
python/cutlass_api/cutlass_api/__init__.py
Normal file
71
python/cutlass_api/cutlass_api/__init__.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
__version__ = "0.1.0a1"
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from cutlass_api import fusion
|
||||
from cutlass_api.arguments import RuntimeArguments
|
||||
from cutlass_api.kernel import Kernel
|
||||
from cutlass_api.manifest import Manifest
|
||||
from cutlass_api.metadata import KernelMetadata
|
||||
from cutlass_api.status import Status
|
||||
|
||||
|
||||
def get_kernels(
|
||||
args: RuntimeArguments = None,
|
||||
metadata_filter: Callable[[KernelMetadata], bool] | None = None,
|
||||
cc: int = None,
|
||||
providers: list[str] = None,
|
||||
) -> list[Kernel]:
|
||||
"""
|
||||
Get the kernels that match the given arguments, metadata filter, and compute capability.
|
||||
|
||||
:param args: the arguments of the kernel
|
||||
:type args: RuntimeArguments
|
||||
:param metadata_filter: a boolean function that takes in KernelMetadata and returns whether
|
||||
a Kernel from this metadata should be included
|
||||
:type metadata_filter: Callable[[KernelMetadata], bool]
|
||||
:param cc: the compute capability
|
||||
:type cc: int
|
||||
:param providers: the providers to use
|
||||
:type providers: list[str]
|
||||
|
||||
:return: the kernels that match the given arguments, metadata filter, and compute capability
|
||||
:rtype: list[Kernel]
|
||||
"""
|
||||
return Manifest.get_kernels(args, metadata_filter, cc, providers)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Manifest",
|
||||
"Status",
|
||||
"fusion",
|
||||
"get_kernels",
|
||||
]
|
||||
363
python/cutlass_api/cutlass_api/arguments.py
Normal file
363
python/cutlass_api/cutlass_api/arguments.py
Normal file
@@ -0,0 +1,363 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import TYPE_CHECKING, get_type_hints
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from cutlass_api.metadata import (
|
||||
ElementwiseOperandsMetadata,
|
||||
GemmOperandsMetadata,
|
||||
)
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
|
||||
from cutlass_api.fusion import EmptyTensor, trace, trace_in_out
|
||||
from cutlass_api.fusion.library import LayoutType
|
||||
from cutlass_api.typing import NumericLike, TensorLike
|
||||
from cutlass_api.utils import (
|
||||
TensorWrapper,
|
||||
add_batch_mode,
|
||||
is_torch_tensor,
|
||||
to_cutlass_type,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerformanceControls:
|
||||
pass
|
||||
|
||||
|
||||
class EpilogueArguments:
|
||||
def __init__(self, epilogue_fn: Callable | str | None = None, **kwargs):
|
||||
"""
|
||||
Encapsulation of the epilogue function and its arguments needed to
|
||||
determine the functional operation of an epilogue pattern.
|
||||
|
||||
To support flexible definition of epilogues, `EpilogueArguments` is
|
||||
defined generically as taking in an `epilogue_fn` and additional `kwargs`.
|
||||
|
||||
Under the hood, the AST for `epilogue_fn` is parsed to determine the
|
||||
operands and outputs of the epilogue. `kwargs` must contain Tensors or scalars
|
||||
for all operands and outputs in the provided epilogue.
|
||||
|
||||
Structure of `epilogue_fn`
|
||||
--------------------------
|
||||
Epilogue fusion patterns are defined via Python functions that perform Tensor-level
|
||||
operations -- using a `torch.Tensor` (for example) resulting from matrix multiplication,
|
||||
the function must be able to compute the desired results of the epilogue.
|
||||
|
||||
The structure of these functions is as follows:
|
||||
```python
|
||||
def custom_epi_name(accum, *args) -> Union[TensorType, tuple[TensorType]]:
|
||||
'''
|
||||
:param accum: result of matrix multiplication, convolution, etc. before the epilogue
|
||||
:type accum: TensorType
|
||||
:param args: additional arguments to be used in the epilogue (e.g., aux tensors)
|
||||
:type args: list[Union[TensorType, ScalarType]]
|
||||
|
||||
:returns: at least one tensor resulting from the operation of the epilogue
|
||||
:rtype: Union[TensorType, tuple[TensorType]]
|
||||
'''
|
||||
# Do some compute
|
||||
return D # and potentially other values
|
||||
```
|
||||
|
||||
`epilogue_fn` must be a Python function or strign representation of a Python function
|
||||
that **must** satisfy the following constraints:
|
||||
- Take in a first positional argument named `accum` that represents the result
|
||||
of operation just before the epilogue is to be performed. For example, in a
|
||||
GEMM, `accum = A @ B`.
|
||||
- Return at least one tensor that results from computing the epilogue.
|
||||
Currently, the return list must contain at least one output named `D`,
|
||||
though this constraint may be loosened in the future.
|
||||
|
||||
Each additional argument following `accum` in the function definition must be
|
||||
a Tensor or scalar to be loaded. Each variable in the return statement represents
|
||||
a Tensor or scalar to be stored. The underlying implementation of the epilogue in
|
||||
the kernel will determine how operands are loaded and stored.
|
||||
|
||||
Structure of `kwargs`
|
||||
----------------------
|
||||
`kwargs` must contain Tensors or scalars for all operands and outputs in the provided epilogue.
|
||||
For example, with an epilogue of:
|
||||
```python
|
||||
def my_epi(accum, alpha, C, beta):
|
||||
F = (accum * alpha) + (C * beta)
|
||||
D = relu(F)
|
||||
return D, F
|
||||
```
|
||||
A user would need to construct epilogue arguments as follows:
|
||||
```python
|
||||
epi_args = EpilogueArguments(my_epi, alpha=..., C=..., beta=..., D=..., F=...)
|
||||
```
|
||||
|
||||
:param epilogue_fn: The epilogue function to be traced.
|
||||
:type epilogue_fn: Optional[Union[Callable, str]]
|
||||
:param kwargs: Additional keyword arguments consisting of the metadata
|
||||
for operands and outputs of the epilogue function.
|
||||
:type kwargs: dict
|
||||
"""
|
||||
epilogue_inputs: list[str] = []
|
||||
epilogue_outputs: list[str] = []
|
||||
if epilogue_fn is not None:
|
||||
# Parse the epilogue_fn AST to get the required input and output arguments
|
||||
epilogue_inputs, epilogue_outputs = trace_in_out(epilogue_fn)
|
||||
|
||||
# Get required input and output arguments from kwargs
|
||||
self.tensors = OrderedDict()
|
||||
for kw in epilogue_inputs + epilogue_outputs:
|
||||
if kw not in kwargs:
|
||||
raise ValueError(
|
||||
f"Argument {kw} is not provided in the kwargs of the EpilogueArguments constructor"
|
||||
)
|
||||
self.tensors[kw] = kwargs[kw]
|
||||
del kwargs[kw]
|
||||
|
||||
if len(kwargs) > 0:
|
||||
raise ValueError(
|
||||
f"Unexpected keyword arguments for epilogue: {kwargs.keys()}"
|
||||
)
|
||||
|
||||
self.epilogue_fn = epilogue_fn
|
||||
|
||||
@property
|
||||
def parameters(self) -> list[cute.Tensor | cutlass.Numeric]:
|
||||
"""Returns the list of input and output parameters of the epilogue"""
|
||||
return list(self.tensors.values())
|
||||
|
||||
@property
|
||||
def parameter_names(self) -> list[str]:
|
||||
"""Returns the list of names of the input and output parameters of the epilogue"""
|
||||
return list(self.tensors.keys())
|
||||
|
||||
def add_batch_modes(self):
|
||||
for name in self.tensors:
|
||||
self.tensors[name] = add_batch_mode(self.tensors[name])
|
||||
|
||||
def to_tensor_wrappers(self, permute: list[int] | None = None):
|
||||
"""Converts the input and output parameters of the epilogue to TensorWrappers"""
|
||||
for k, v in self.tensors.items():
|
||||
if is_torch_tensor(v):
|
||||
if permute is not None:
|
||||
v = v.permute(permute)
|
||||
|
||||
self.tensors[k] = TensorWrapper(v)
|
||||
|
||||
def trace(
|
||||
self, accumulator_shape: tuple[int, ...], accumulator_type: cutlass.Numeric
|
||||
):
|
||||
"""
|
||||
Traces the epilogue function and generates an internal representation of the epilogue.
|
||||
|
||||
:param accumulator_shape: The shape of the accumulator tensor. For example, for a GEMM, this would be the shape of the output tensor.
|
||||
:type accumulator_shape: tuple[int, ...]
|
||||
:param accumulator_type: The datatype of the accumulator tensor.
|
||||
:type accumulator_type: cutlass.Numeric
|
||||
"""
|
||||
accumulator = EmptyTensor(
|
||||
element=accumulator_type,
|
||||
shape=accumulator_shape,
|
||||
layout_tag=LayoutType.RowMajor,
|
||||
)
|
||||
tensors_for_tracing = {**self.tensors, "accum": accumulator}
|
||||
|
||||
# Parse the AST of the epilogue_fn again, this time with the set of required
|
||||
# tensors. This pass converts the epilogue into an internal representation and
|
||||
# performs a limited set of correctness checks (e.g., shape matches)
|
||||
#
|
||||
# Since all current providers are not based on C++ EVT, we do not need to convert
|
||||
# the DAG to a tree. If a provider that tightly matches the C++ EVT template structure,
|
||||
# this will need to be revisited.
|
||||
self.traced_epilogue = trace(
|
||||
self.epilogue_fn, tensors_for_tracing, requires_conversion_to_tree=False
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeArguments:
|
||||
"""
|
||||
Base class for container of all runtime operands and other runtime parameters needed
|
||||
by a kernel. This includes primary runtime operands to the operation, as well as
|
||||
any custom epilogue fusions and runtime performance knobs.
|
||||
|
||||
Subclasses map to an operation type (e.g., GEMM, elementwise). These subclasses have
|
||||
additional fields that are specific to the operation type.
|
||||
|
||||
:param epilogue: Optional custom epilogue fusion to be performed after the operation.
|
||||
:type epilogue: Optional[EpilogueArguments]
|
||||
:param performance: Optional performance controls to be used by the kernel.
|
||||
:type performance: Optional[PerformanceControls]
|
||||
"""
|
||||
|
||||
epilogue: EpilogueArguments | None = field(default=None, kw_only=True)
|
||||
performance: PerformanceControls | None = field(default=None, kw_only=True)
|
||||
|
||||
def _validate(self):
|
||||
"""
|
||||
Checks that the arguments are valid.
|
||||
This is run before all fields have been converted to TensorWrapper and cutlass.Numeric.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
self._validate()
|
||||
self._convert_to_internal_types()
|
||||
|
||||
def _convert_to_internal_types(self):
|
||||
"""
|
||||
Converts all fields to their internal types.
|
||||
"""
|
||||
type_hints = get_type_hints(type(self))
|
||||
for f in fields(self):
|
||||
hint = type_hints.get(f.name)
|
||||
value = getattr(self, f.name)
|
||||
|
||||
# Find all fields that are annotated as TensorLike,
|
||||
# and wrap them in TensorWrapper
|
||||
if hint is TensorLike:
|
||||
setattr(self, f.name, TensorWrapper(value))
|
||||
elif hint is NumericLike:
|
||||
setattr(self, f.name, to_cutlass_type(value))
|
||||
|
||||
|
||||
@dataclass
|
||||
class GemmArguments(RuntimeArguments):
|
||||
"""
|
||||
Arguments for a Generalized Matrix Multiplication (GEMM) operation: out = A @ B
|
||||
|
||||
The tensors must be all rank-3 or all rank-2.
|
||||
L: Number of batches
|
||||
M: Number of rows in A and out
|
||||
K: Number of columns in A and rows in B
|
||||
N: Number of columns in B and out
|
||||
|
||||
:param A: Input tensor A of shape (L, M, K) or (M, K)
|
||||
:type A: TensorWrapper
|
||||
:param B: Input tensor B of shape (L, K, N) or (K, N)
|
||||
:type B: TensorWrapper
|
||||
:param out: Output tensor C of shape (L, M, N) or (M, N)
|
||||
:type out: TensorWrapper
|
||||
:param accumulator_type: Data type of the accumulator
|
||||
:type accumulator_type: cutlass.Numeric
|
||||
"""
|
||||
|
||||
A: TensorLike
|
||||
B: TensorLike
|
||||
out: TensorLike
|
||||
accumulator_type: NumericLike
|
||||
|
||||
def to_operands_metadata(self) -> GemmOperandsMetadata:
|
||||
from cutlass_api.metadata import GemmOperandsMetadata
|
||||
|
||||
return GemmOperandsMetadata.from_args(self)
|
||||
|
||||
def _validate(self):
|
||||
"""
|
||||
Checks that the arguments are valid.
|
||||
"""
|
||||
if len(self.A.shape) < 2 or len(self.A.shape) > 3:
|
||||
raise ValueError(
|
||||
f"A must be a tensor of rank 2 or 3 (L=1, M, K), got {self.A.shape}"
|
||||
)
|
||||
if len(self.B.shape) < 2 or len(self.B.shape) > 3:
|
||||
raise ValueError(
|
||||
f"B must be a tensor of rank 2 or 3 (L=1, K, N), got {len(self.B.shape)}"
|
||||
)
|
||||
if len(self.out.shape) < 2 or len(self.out.shape) > 3:
|
||||
raise ValueError(
|
||||
f"out must be a tensor of rank 2 or 3 (L=1, M, N), got {len(self.out.shape)}"
|
||||
)
|
||||
if self.A.shape[-1] != self.B.shape[-2]:
|
||||
raise ValueError(
|
||||
f"A's K dimension ({self.A.shape[-1]}) must be equal to B's K dimension ({self.B.shape[-2]}). A shape (L, M, K): {self.A.shape}, B shape (L, K, N): {self.B.shape}"
|
||||
)
|
||||
if self.out.shape[-2] != self.A.shape[-2]:
|
||||
raise ValueError(
|
||||
f"out's M dimension ({self.out.shape[-2]}) must be equal to A's M dimension ({self.A.shape[-2]}). A shape (L, M, K): {self.A.shape}, out shape (L, M, N): {self.out.shape}"
|
||||
)
|
||||
if self.out.shape[-1] != self.B.shape[-1]:
|
||||
raise ValueError(
|
||||
f"out's N dimension ({self.out.shape[-1]}) must be equal to B's N dimension ({self.B.shape[-1]}). B shape (L, K, N): {self.B.shape}, out shape (L, M, N): {self.out.shape}"
|
||||
)
|
||||
if self.A.shape[:-2] != self.B.shape[:-2]:
|
||||
raise ValueError(
|
||||
f"A & B must have the same rank and batch dimension (if any). A shape (L, M, K): {self.A.shape}, B shape (L, K, N): {self.B.shape}"
|
||||
)
|
||||
if self.out.shape[:-2] != self.A.shape[:-2]:
|
||||
raise ValueError(
|
||||
f"out & A must have the same rank and batch dimension (if any). out shape (L, M, N): {self.out.shape}, A shape (L, M, K): {self.A.shape}"
|
||||
)
|
||||
|
||||
if isinstance(self.epilogue, EpilogueArguments):
|
||||
L = self.A.shape[0] if len(self.A.shape) == 3 else 1
|
||||
M, N = self.A.shape[-2], self.B.shape[-1]
|
||||
accum_shape = (L, M, N)
|
||||
self.epilogue.trace(accum_shape, self.accumulator_type)
|
||||
self.epilogue.to_tensor_wrappers()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ElementwiseArguments(RuntimeArguments):
|
||||
"""
|
||||
Arguments needed for an elementwise operation.
|
||||
|
||||
:param A: The input tensor A.
|
||||
:type A: TensorLike
|
||||
:param B: The input tensor B.
|
||||
:type B: TensorLike
|
||||
:param out: The output tensor.
|
||||
:type out: TensorLike
|
||||
"""
|
||||
|
||||
A: TensorLike
|
||||
B: TensorLike
|
||||
out: TensorLike
|
||||
|
||||
def to_operands_metadata(self) -> ElementwiseOperandsMetadata:
|
||||
from cutlass_api.metadata import ElementwiseOperandsMetadata
|
||||
|
||||
return ElementwiseOperandsMetadata.from_args(self)
|
||||
|
||||
def _validate(self):
|
||||
"""
|
||||
Checks that the arguments are valid.
|
||||
"""
|
||||
if self.A.shape != self.B.shape:
|
||||
raise ValueError(
|
||||
f"A.shape ({self.A.shape}) must be equal to B.shape ({self.B.shape})"
|
||||
)
|
||||
if self.out.shape != self.A.shape:
|
||||
raise ValueError(
|
||||
f"out.shape ({self.out.shape}) must be equal to A.shape ({self.A.shape})"
|
||||
)
|
||||
47
python/cutlass_api/cutlass_api/artifact.py
Normal file
47
python/cutlass_api/cutlass_api/artifact.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompiledArtifact:
|
||||
"""
|
||||
Lightweight wrapper over the result of compiling a kernel (e.g., via
|
||||
`cute.compile` when using CuTe DSL).
|
||||
|
||||
:param compiled_obj: The result of compiling the kernel.
|
||||
:type compiled_obj: Any
|
||||
:param kernel_obj: The Kernel object on which `.compile()` was called to generate this artifact.
|
||||
:type kernel_obj: cutlass_api.kernel.Kernel
|
||||
"""
|
||||
|
||||
compiled_obj: Any
|
||||
kernel_obj: "cutlass_api.kernel.Kernel"
|
||||
83
python/cutlass_api/cutlass_api/config.py
Normal file
83
python/cutlass_api/cutlass_api/config.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from importlib.util import find_spec
|
||||
|
||||
|
||||
class GlobalOptions:
|
||||
"""
|
||||
Singleton class that configures global options for CUTLASS API.
|
||||
|
||||
This can be used to enable or disable certain features of the API. For example,
|
||||
the user can enable TVM-FFI support to allow for the use of framework-level tensors
|
||||
at run time.
|
||||
|
||||
Example:
|
||||
```python
|
||||
GlobalOptions().use_tvm_ffi = True
|
||||
```
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
"""
|
||||
Create a new singleton instance of the GlobalOptions class only once.
|
||||
"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
|
||||
has_tvm_ffi = find_spec("tvm_ffi") is not None
|
||||
cls._instance._options = {
|
||||
"use_tvm_ffi": has_tvm_ffi,
|
||||
}
|
||||
return cls._instance
|
||||
|
||||
@property
|
||||
def use_tvm_ffi(self) -> bool:
|
||||
"""
|
||||
Check if TVM FFI support is enabled.
|
||||
Default: True if `tvm_ffi` is installed.
|
||||
|
||||
When enabled, conversions from DLPack compatible tensors to cute.Tensor is via TVM FFI.
|
||||
Additionally, invoking the compiled kernel happens via TVM FFI.
|
||||
Both can offer significant (3x-10x) speedups.
|
||||
|
||||
Dependencies:
|
||||
- Required: `tvm_ffi` (pip install apache-tvm-ffi)
|
||||
- Optional: `torch_c_dlpack_ext` (pip install torch-c-dlpack-ext)
|
||||
"""
|
||||
return self._options["use_tvm_ffi"]
|
||||
|
||||
@use_tvm_ffi.setter
|
||||
def use_tvm_ffi(self, value: bool) -> None:
|
||||
if value and not find_spec("tvm_ffi"):
|
||||
raise ImportError(
|
||||
"TVM FFI is not installed, please install it via `pip install apache-tvm-ffi`."
|
||||
)
|
||||
self._options["use_tvm_ffi"] = value
|
||||
127
python/cutlass_api/cutlass_api/fusion/__init__.py
Normal file
127
python/cutlass_api/cutlass_api/fusion/__init__.py
Normal file
@@ -0,0 +1,127 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
import ast
|
||||
import textwrap
|
||||
from typing import Callable, Union
|
||||
|
||||
from cutlass_api.fusion.ir.tensor import Tensor as EmptyTensor
|
||||
from cutlass_api.fusion.frontend import PythonASTFrontend, PythonASTInOutProcessor
|
||||
|
||||
|
||||
def trace_in_out(fn: Union[Callable, str]) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
Trace the input function and returns the names of inputs and outputs
|
||||
|
||||
:param fn: The function to trace.
|
||||
:type fn: Union[Callable, str]
|
||||
:return: A tuple of lists of input and output names.
|
||||
:rtype: tuple[list[str], list[str]]
|
||||
"""
|
||||
processor = PythonASTInOutProcessor()
|
||||
return processor.trace(fn)
|
||||
|
||||
|
||||
def trace(fn, example_tensors, **kwargs):
|
||||
"""
|
||||
Trace `fn(**example_tensors)` and generates epilogue visitor
|
||||
|
||||
:param fn or str: Python callable or string of the epilogue function
|
||||
:param example_tensors: example inputs for fn
|
||||
:type example_tensors: dict
|
||||
|
||||
.. hightlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
# Define epilogue function as Python callable
|
||||
def example_fn(accum, C, alpha, beta, gamma):
|
||||
D = ((accum + C) * alpha - gamma) / beta
|
||||
return D
|
||||
|
||||
# Define the example tensors
|
||||
example_inputs = {
|
||||
"accum": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
|
||||
"C": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
|
||||
"alpha": 1.5,
|
||||
"beta": 0.5,
|
||||
"gamma": 2.5,
|
||||
"D": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda")
|
||||
}
|
||||
|
||||
# Generate the epilogue functor
|
||||
epilogue_visitor = cutlass_api.fusion.trace(example_fn, example_inputs)
|
||||
"""
|
||||
if callable(fn):
|
||||
|
||||
class EpilogueFunctor(PythonASTFrontend):
|
||||
def __init__(self, cc=None, **kwargs):
|
||||
# Since we are currently only using the trace() method for generating an
|
||||
# intermiediate representation (which is not CC specific), we can hardcode the cc to 100
|
||||
if not cc:
|
||||
cc = 100
|
||||
super().__init__(cc, **kwargs)
|
||||
|
||||
pass
|
||||
|
||||
setattr(EpilogueFunctor, "__call__", staticmethod(fn))
|
||||
|
||||
epilogue_functor = EpilogueFunctor(**kwargs)
|
||||
epilogue_functor.trace(example_tensors)
|
||||
return epilogue_functor
|
||||
elif isinstance(fn, str):
|
||||
|
||||
class EpilogueFunctor(PythonASTFrontend):
|
||||
def __init__(self, cc=None, **kwargs):
|
||||
self.source = textwrap.dedent(fn)
|
||||
# Since we are currently only using the trace() method for generating an
|
||||
# intermiediate representation (which is not CC specific), we can hardcode the cc to 100
|
||||
if not cc:
|
||||
cc = 100
|
||||
super().__init__(cc, **kwargs)
|
||||
|
||||
def parse(self, example_inputs) -> None:
|
||||
self.example_inputs = example_inputs
|
||||
self.ast = ast.parse(self.source)
|
||||
self.visit(self.ast)
|
||||
|
||||
epilogue_functor = EpilogueFunctor(**kwargs)
|
||||
epilogue_functor.trace(example_tensors)
|
||||
return epilogue_functor
|
||||
else:
|
||||
raise NotImplementedError("Expect a callable Python function")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"EmptyTensor",
|
||||
"trace_in_out",
|
||||
"trace",
|
||||
]
|
||||
237
python/cutlass_api/cutlass_api/fusion/activation.py
Normal file
237
python/cutlass_api/cutlass_api/fusion/activation.py
Normal file
@@ -0,0 +1,237 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Helpers for common activation functions
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
|
||||
from cutlass_api.fusion.ir.c_types import dtype2ctype, to_ctype_value
|
||||
from cutlass_api.fusion.library import ActivationOp
|
||||
from cutlass_api.utils import (
|
||||
is_torch_available,
|
||||
is_numpy_available,
|
||||
is_numpy_tensor,
|
||||
is_torch_tensor,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
if is_numpy_available():
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ActivationFunctor:
|
||||
"""
|
||||
Base class for frequently used activation functions
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def numpy(x):
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def epilogue_output_op(element_epilogue):
|
||||
c_element_epilogue = dtype2ctype[element_epilogue]
|
||||
|
||||
class _EpilogueOutputOpParams(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("alpha", c_element_epilogue),
|
||||
("beta", c_element_epilogue),
|
||||
("alpha_ptr", ctypes.c_void_p),
|
||||
("beta_ptr", ctypes.c_void_p),
|
||||
]
|
||||
|
||||
def __init__(self, alpha, beta, *args) -> None:
|
||||
self.alpha = to_ctype_value(alpha, element_epilogue)
|
||||
self.beta = to_ctype_value(beta, element_epilogue)
|
||||
|
||||
return _EpilogueOutputOpParams
|
||||
|
||||
|
||||
class ActivationMeta(type):
|
||||
@classmethod
|
||||
def __call__(cls, x, *args):
|
||||
if is_numpy_tensor(x):
|
||||
return cls.numpy(x, *args)
|
||||
elif is_torch_tensor(x):
|
||||
return cls.torch(x, *args)
|
||||
else:
|
||||
raise NotImplementedError("Unsupported tensor type")
|
||||
|
||||
@classmethod
|
||||
def numpy(cls, *args):
|
||||
raise NotImplementedError(
|
||||
f"Numpy reference for {cls.__name__[:-4]} is not implemented."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def torch(cls, *args):
|
||||
raise NotImplementedError(
|
||||
f"PyTorch reference for {cls.__name__[:-4]} is not implemented."
|
||||
)
|
||||
|
||||
|
||||
##############################################################################
|
||||
# identity operator
|
||||
class identityMeta(ActivationMeta):
|
||||
@classmethod
|
||||
def numpy(cls, x):
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def torch(cls, x):
|
||||
return x
|
||||
|
||||
|
||||
class identity(ActivationFunctor, metaclass=identityMeta):
|
||||
binding_type = ActivationOp.Identity
|
||||
|
||||
|
||||
##############################################################################
|
||||
# ReLu operator
|
||||
class reluMeta(ActivationMeta):
|
||||
@classmethod
|
||||
def numpy(cls, x):
|
||||
return np.where(x > 0, x, 0)
|
||||
|
||||
@classmethod
|
||||
def torch(cls, x):
|
||||
return F.relu(x)
|
||||
|
||||
|
||||
class relu(ActivationFunctor, metaclass=reluMeta):
|
||||
binding_type = ActivationOp.ReLU
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Leaky ReLu operator
|
||||
class leakyReLUMeta(ActivationMeta):
|
||||
@classmethod
|
||||
def numpy(cls, x, leaky_alpha):
|
||||
return np.maximum(x, 0) + np.minimum(x, 0) * leaky_alpha
|
||||
|
||||
@classmethod
|
||||
def torch(cls, x, leaky_alpha):
|
||||
return F.leaky_relu(x, leaky_alpha)
|
||||
|
||||
|
||||
class leaky_relu(ActivationFunctor, metaclass=leakyReLUMeta):
|
||||
binding_type = ActivationOp.LeakyReLU
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Tanh operator
|
||||
class tanhMeta(ActivationMeta):
|
||||
@classmethod
|
||||
def numpy(cls, x):
|
||||
return np.tanh(x)
|
||||
|
||||
@classmethod
|
||||
def torch(cls, x):
|
||||
return torch.tanh(x)
|
||||
|
||||
|
||||
class tanh(ActivationFunctor, metaclass=tanhMeta):
|
||||
binding_type = ActivationOp.Tanh
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Sigmoid operator
|
||||
class sigmoidMeta(ActivationMeta):
|
||||
@classmethod
|
||||
def numpy(cls, x):
|
||||
return 1.0 / (1.0 + np.exp(-x))
|
||||
|
||||
@classmethod
|
||||
def torch(cls, x):
|
||||
return F.sigmoid(x)
|
||||
|
||||
|
||||
class sigmoid(ActivationFunctor, metaclass=sigmoidMeta):
|
||||
binding_type = ActivationOp.Sigmoid
|
||||
|
||||
|
||||
##############################################################################
|
||||
# SiLu operator
|
||||
class siluMeta(ActivationMeta):
|
||||
@classmethod
|
||||
def numpy(cls, x):
|
||||
return x * sigmoidMeta.numpy(x)
|
||||
|
||||
@classmethod
|
||||
def torch(cls, x):
|
||||
return F.silu(x)
|
||||
|
||||
|
||||
class silu(ActivationFunctor, metaclass=siluMeta):
|
||||
binding_type = ActivationOp.SiLU
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Hardswish operator
|
||||
class hardswishMeta(ActivationMeta):
|
||||
@classmethod
|
||||
def numpy(cls, x):
|
||||
relu6 = np.minimum(np.maximum(x + 3.0, 0), 6.0)
|
||||
return x * relu6 / 6.0
|
||||
|
||||
@classmethod
|
||||
def torch(cls, x):
|
||||
return F.hardswish(x)
|
||||
|
||||
|
||||
class hardswish(ActivationFunctor, metaclass=hardswishMeta):
|
||||
binding_type = ActivationOp.HardSwish
|
||||
|
||||
|
||||
##############################################################################
|
||||
# GELU operator
|
||||
class geluMeta(ActivationMeta):
|
||||
@classmethod
|
||||
def numpy(cls, x):
|
||||
from scipy.special import erf
|
||||
|
||||
return 0.5 * x * (1 + erf(x / np.sqrt(2.0)))
|
||||
|
||||
@classmethod
|
||||
def torch(cls, x):
|
||||
return F.gelu(x)
|
||||
|
||||
|
||||
class gelu(ActivationFunctor, metaclass=geluMeta):
|
||||
binding_type = ActivationOp.Gelu
|
||||
47
python/cutlass_api/cutlass_api/fusion/backend/__init__.py
Normal file
47
python/cutlass_api/cutlass_api/fusion/backend/__init__.py
Normal file
@@ -0,0 +1,47 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass_api.fusion.backend.sm80_emitter import Sm80Emitter
|
||||
import cutlass_api.fusion.backend.sm80_nodes as sm80_nodes
|
||||
from cutlass_api.fusion.backend.sm90_emitter import Sm90Emitter
|
||||
import cutlass_api.fusion.backend.sm90_nodes as sm90_nodes
|
||||
from cutlass_api.fusion.backend.sm100_emitter import Sm100Emitter
|
||||
import cutlass_api.fusion.backend.sm100_nodes as sm100_nodes
|
||||
|
||||
__all__ = [
|
||||
"Sm80Emitter",
|
||||
"sm80_nodes",
|
||||
"Sm90Emitter",
|
||||
"sm90_nodes",
|
||||
"Sm100Emitter",
|
||||
"sm100_nodes",
|
||||
]
|
||||
165
python/cutlass_api/cutlass_api/fusion/backend/emitter_base.py
Normal file
165
python/cutlass_api/cutlass_api/fusion/backend/emitter_base.py
Normal file
@@ -0,0 +1,165 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Base class for Epilogue Visitor Emitter
|
||||
"""
|
||||
|
||||
from cutlass_api.fusion.library import DataTypeTag
|
||||
from cutlass_api.fusion.ir import TopoVisitorNode, DAGIR
|
||||
|
||||
|
||||
class FusionCallbacks:
|
||||
def __init__(self, dag_ir: DAGIR, cc: int, emit_CD=True) -> None:
|
||||
"""
|
||||
Emit the EVT fusion callbacks
|
||||
:param dag_ir: the DAG IR holding the epilogue visitor
|
||||
:param cc: compute capability
|
||||
:param emit_CD: whether to emit nodes C & D as a part of the fusion callbacks
|
||||
For Sm90, set emit_CD=False, as Tensor C & D are hardcoded in the collective API
|
||||
so that their shared memory can be explicitly reused
|
||||
For Sm89, set emit_CD=True as they are treated as normal AuxLoad & AuxStore nodes.
|
||||
"""
|
||||
self.dag_ir = dag_ir
|
||||
self.emit_CD = emit_CD
|
||||
self.cc = cc
|
||||
self.evt_cc = 90 if cc >= 90 else cc
|
||||
if self.cc < 90:
|
||||
self.namespace = "threadblock"
|
||||
else:
|
||||
self.namespace = "fusion"
|
||||
|
||||
#
|
||||
# Helper functions
|
||||
#
|
||||
|
||||
def get_visitor_name(self, node: str):
|
||||
"""
|
||||
Get the visitor name
|
||||
"""
|
||||
meta = self.dag_ir.get_node_meta(node)
|
||||
if not isinstance(meta, TopoVisitorNode) and self.dag_ir.in_degree(node) > 0:
|
||||
return f"EVT{meta.name_camel}"
|
||||
else:
|
||||
return meta.name_camel
|
||||
|
||||
def emit(self):
|
||||
node_metas = self.dag_ir.node_metas_topological_order()
|
||||
epilogue_str = ""
|
||||
# Step 1: emit individual node type decl
|
||||
# emit the EVT & DAG connector
|
||||
for meta in node_metas:
|
||||
if not meta.disabled:
|
||||
epilogue_str += self.emit_node(meta)
|
||||
if not self.emit_CD and meta.name == "D":
|
||||
continue
|
||||
if isinstance(meta, TopoVisitorNode):
|
||||
epilogue_str += self.emit_dag(meta)
|
||||
else:
|
||||
epilogue_str += self.emit_evt(meta)
|
||||
|
||||
# Step 2: post-processing & get callback name
|
||||
if not self.emit_CD:
|
||||
if not self.dag_ir.has_node("C"):
|
||||
epilogue_str += "using ElementC = void;\nusing StrideC = StrideD;\n"
|
||||
output_node = self.dag_ir.get_all_inputs("D")[0]
|
||||
# The callback is the src of node D
|
||||
callback_name = self.get_visitor_name(output_node)
|
||||
else:
|
||||
# The callback is the last node in the topological order
|
||||
callback_name = self.get_visitor_name(node_metas[-1].name)
|
||||
return epilogue_str, callback_name
|
||||
|
||||
def emit_evt(self, node):
|
||||
if self.dag_ir.in_degree(node.name) == 0:
|
||||
return ""
|
||||
|
||||
evt_tmp = f"""
|
||||
using EVT{node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}EVT<
|
||||
{node.name_camel},
|
||||
"""
|
||||
sorted_children = self.dag_ir.get_all_inputs(node.name)
|
||||
evt_node_strs = [
|
||||
f" {self.get_visitor_name(child_name)}" for child_name in sorted_children
|
||||
]
|
||||
evt_tmp += ",\n".join(evt_node_strs) + ">;\n"
|
||||
|
||||
return evt_tmp
|
||||
|
||||
def emit_dag(self, node):
|
||||
subgraph = node.subgraph
|
||||
subgraph_nodes = subgraph.nodes_topological_order()
|
||||
# Emit the Edge Tuple
|
||||
edge_tuples = "cute::tuple<\n"
|
||||
for n in subgraph_nodes[:-1]:
|
||||
in_edges = subgraph.in_edges(n)
|
||||
edge_weights = [
|
||||
subgraph.get_edge_weight(edge[0], edge[1]) for edge in in_edges
|
||||
]
|
||||
sorted_children = [
|
||||
edge[0] for _, edge in sorted(zip(edge_weights, in_edges))
|
||||
]
|
||||
edge_tuple = " cute::seq<"
|
||||
edge_str = [str(subgraph_nodes.index(child)) for child in sorted_children]
|
||||
edge_tuple += ", ".join(edge_str) + ">,\n"
|
||||
|
||||
edge_tuples += edge_tuple
|
||||
edge_tuples += " >"
|
||||
|
||||
# Emit the node list
|
||||
dag_nodes = ""
|
||||
dag_node_strs = []
|
||||
for n in subgraph_nodes[:-1]:
|
||||
n_meta = subgraph.get_node_meta(n)
|
||||
if n_meta.disabled:
|
||||
dag_node_strs.append(f" {self.get_visitor_name(n)}")
|
||||
else:
|
||||
dag_node_strs.append(f" {n_meta.name_camel}")
|
||||
dag_nodes = ",\n".join(dag_node_strs)
|
||||
|
||||
return f"""
|
||||
using {node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}TopologicalVisitor<
|
||||
{DataTypeTag[node.subgraph.element_compute]},
|
||||
{edge_tuples},
|
||||
{dag_nodes}
|
||||
>;
|
||||
"""
|
||||
|
||||
def emit_node(self, node):
|
||||
if isinstance(node, TopoVisitorNode):
|
||||
emission = ""
|
||||
for meta in node.subgraph.node_metas_topological_order():
|
||||
if not meta.disabled:
|
||||
emission += self.emit_node(meta)
|
||||
return emission
|
||||
else:
|
||||
return node.underlying_impl.type_decl
|
||||
151
python/cutlass_api/cutlass_api/fusion/backend/sm100_emitter.py
Normal file
151
python/cutlass_api/cutlass_api/fusion/backend/sm100_emitter.py
Normal file
@@ -0,0 +1,151 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Emitter for Sm100 Epilogue Visitor
|
||||
"""
|
||||
|
||||
from cutlass_api.fusion.library import (
|
||||
DataType,
|
||||
DataTypeTag,
|
||||
EpilogueScheduleTag,
|
||||
KernelScheduleSuffixes,
|
||||
OpcodeClassTag,
|
||||
)
|
||||
from cutlass_api.fusion.backend.emitter_base import FusionCallbacks
|
||||
|
||||
|
||||
def to_blackwell_threadblock_shape(tile_description, cluster_shape, kernel_schedule):
|
||||
blackwell_threadblock_shape = tile_description.threadblock_shape
|
||||
is_2sm = (
|
||||
False
|
||||
if kernel_schedule is None
|
||||
else ("2sm" in KernelScheduleSuffixes[kernel_schedule])
|
||||
)
|
||||
if cluster_shape[0] > 0:
|
||||
blackwell_threadblock_shape = [
|
||||
tile_description.threadblock_shape[0] // cluster_shape[0],
|
||||
tile_description.threadblock_shape[1] // cluster_shape[1],
|
||||
tile_description.threadblock_shape[2] // cluster_shape[2],
|
||||
]
|
||||
if is_2sm:
|
||||
blackwell_threadblock_shape[0] *= 2
|
||||
else:
|
||||
blackwell_threadblock_shape = (
|
||||
tile_description.math_instruction.instruction_shape
|
||||
)
|
||||
return blackwell_threadblock_shape, is_2sm
|
||||
|
||||
|
||||
class Sm100CollectiveEpilogue:
|
||||
def __init__(
|
||||
self,
|
||||
tile_description,
|
||||
kernel_schedule,
|
||||
epilogue_schedule,
|
||||
element_accumulator,
|
||||
element_d,
|
||||
fusion_callbacks,
|
||||
) -> None:
|
||||
self.cta_tile_mnk, _ = to_blackwell_threadblock_shape(
|
||||
tile_description, tile_description.cluster_shape, kernel_schedule
|
||||
)
|
||||
self.element_accumulator = element_accumulator
|
||||
if fusion_callbacks.dag_ir.has_node("C"):
|
||||
self.element_c = fusion_callbacks.dag_ir.get_node_meta("C").element
|
||||
else:
|
||||
self.element_c = DataType.void
|
||||
self.element_d = element_d
|
||||
self.schedule = epilogue_schedule
|
||||
self.fusion_callbacks = fusion_callbacks
|
||||
self.opclass = tile_description.math_instruction.opcode_class
|
||||
|
||||
@property
|
||||
def CtaTileMNK(self) -> str:
|
||||
"""
|
||||
The threadblock shape
|
||||
"""
|
||||
return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>"
|
||||
|
||||
@property
|
||||
def EpilogueTileType(self) -> str:
|
||||
"""
|
||||
The epilogue tile type
|
||||
"""
|
||||
return "cutlass::epilogue::collective::EpilogueTileAuto"
|
||||
|
||||
@property
|
||||
def Schedule(self) -> str:
|
||||
return EpilogueScheduleTag[self.schedule]
|
||||
|
||||
def emit(self):
|
||||
stride_D_str = self.fusion_callbacks.dag_ir.get_node_meta(
|
||||
"D"
|
||||
).underlying_impl.stride_mnl
|
||||
stride_C_str = stride_D_str
|
||||
if self.fusion_callbacks.dag_ir.has_node("C"):
|
||||
stride_C_str = self.fusion_callbacks.dag_ir.get_node_meta(
|
||||
"C"
|
||||
).underlying_impl.stride_mnl
|
||||
|
||||
callback_decl, callback_name = self.fusion_callbacks.emit()
|
||||
return (
|
||||
callback_name,
|
||||
f"""
|
||||
using EpilogueDescriptor = cutlass::epilogue::collective::detail::Sm100EpilogueDescriptor<
|
||||
{OpcodeClassTag[self.opclass]},
|
||||
{self.CtaTileMNK}, {self.EpilogueTileType},
|
||||
{DataTypeTag[self.element_accumulator]}, {DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
|
||||
{self.Schedule}, {stride_C_str}, {stride_D_str},
|
||||
false /* IsPerColScaleSupported */,
|
||||
false /* IsBlockScaleSupported */
|
||||
>;
|
||||
{callback_decl}
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class Sm100Emitter:
|
||||
def __init__(self, operation, graph) -> None:
|
||||
fusion_callbacks = FusionCallbacks(graph, cc=100, emit_CD=False)
|
||||
|
||||
self.collective_epilogue = Sm100CollectiveEpilogue(
|
||||
tile_description=operation.tile_description,
|
||||
kernel_schedule=operation.tile_description.kernel_schedule,
|
||||
epilogue_schedule=operation.tile_description.epilogue_schedule,
|
||||
element_accumulator=operation.tile_description.math_instruction.element_accumulator,
|
||||
element_d=fusion_callbacks.dag_ir.get_node_meta("D").element,
|
||||
fusion_callbacks=fusion_callbacks,
|
||||
)
|
||||
|
||||
def emit(self):
|
||||
return self.collective_epilogue.emit()
|
||||
140
python/cutlass_api/cutlass_api/fusion/backend/sm100_nodes.py
Normal file
140
python/cutlass_api/cutlass_api/fusion/backend/sm100_nodes.py
Normal file
@@ -0,0 +1,140 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
from cutlass_api.fusion.ir import AuxLoadImpl, AuxStoreImpl
|
||||
from cutlass_api.fusion.library import DataTypeSize, DataTypeTag, FloatRoundStyleTag
|
||||
import cutlass_api.fusion.backend.sm90_nodes as sm90_nodes
|
||||
from cutlass_api.fusion.pycute import product
|
||||
|
||||
|
||||
Sm100AccumulatorImpl = sm90_nodes.Sm90AccumulatorImpl
|
||||
Sm100LoadSrcImpl = sm90_nodes.Sm90LoadSrcImpl
|
||||
Sm100ScalarBroadcastImpl = sm90_nodes.Sm90ScalarBroadcastImpl
|
||||
Sm100RowBroadcastImpl = sm90_nodes.Sm90RowBroadcastImpl
|
||||
Sm100ColumnBroadcastImpl = sm90_nodes.Sm90ColumnBroadcastImpl
|
||||
Sm100ComputeImpl = sm90_nodes.Sm90ComputeImpl
|
||||
Sm100StoreDImpl = sm90_nodes.Sm90StoreDImpl
|
||||
Sm100ColumnReductionImpl = sm90_nodes.Sm90ColumnReductionImpl
|
||||
Sm100RowReductionImpl = sm90_nodes.Sm90RowReductionImpl
|
||||
Sm100ScalarReductionImpl = sm90_nodes.Sm90ScalarReductionImpl
|
||||
|
||||
|
||||
class Sm100AuxLoadImpl(AuxLoadImpl):
|
||||
@property
|
||||
def descriptor(self) -> str:
|
||||
"""
|
||||
Descriptor for Aux Load
|
||||
"""
|
||||
return f"{self.name_camel}Descriptor"
|
||||
|
||||
def decl_descriptor(self) -> str:
|
||||
"""
|
||||
Declare the descriptor type
|
||||
"""
|
||||
return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxLoadDescriptor<EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}>;\n"
|
||||
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = self.decl_descriptor()
|
||||
self._type_decl += f"""
|
||||
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad<
|
||||
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
|
||||
{self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
def get_smem_size(
|
||||
self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles
|
||||
):
|
||||
"""
|
||||
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
|
||||
"""
|
||||
return (
|
||||
DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8,
|
||||
128,
|
||||
)
|
||||
|
||||
|
||||
class Sm100AuxStoreImpl(AuxStoreImpl):
|
||||
@property
|
||||
def descriptor(self) -> str:
|
||||
"""
|
||||
Descriptor for Aux Load
|
||||
"""
|
||||
return f"{self.name_camel}Descriptor"
|
||||
|
||||
def decl_descriptor(self) -> str:
|
||||
"""
|
||||
Declare the descriptor type
|
||||
"""
|
||||
return f"""
|
||||
using {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxStoreDescriptor<
|
||||
EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}
|
||||
>;
|
||||
"""
|
||||
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = self.decl_descriptor()
|
||||
self._type_decl += f"""
|
||||
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore<
|
||||
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
|
||||
{FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom,
|
||||
typename {self.descriptor}::CopyOpR2S
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
def get_smem_size(
|
||||
self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles
|
||||
):
|
||||
"""
|
||||
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
|
||||
"""
|
||||
return (
|
||||
DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8,
|
||||
128,
|
||||
)
|
||||
@@ -0,0 +1,46 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Emitter for Sm80 Epilogue Visitor
|
||||
"""
|
||||
|
||||
from cutlass_api.fusion.backend.emitter_base import FusionCallbacks
|
||||
|
||||
|
||||
class Sm80Emitter:
|
||||
def __init__(self, operation, graph) -> None:
|
||||
self.fusion_callbacks = FusionCallbacks(graph, cc=80)
|
||||
|
||||
def emit(self):
|
||||
callback_decl, callback_name = self.fusion_callbacks.emit()
|
||||
return callback_name, callback_decl
|
||||
247
python/cutlass_api/cutlass_api/fusion/backend/sm80_nodes.py
Normal file
247
python/cutlass_api/cutlass_api/fusion/backend/sm80_nodes.py
Normal file
@@ -0,0 +1,247 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass_api.fusion.library import (
|
||||
DataTypeTag,
|
||||
FloatRoundStyleTag,
|
||||
FunctionalOp,
|
||||
op_tag,
|
||||
)
|
||||
|
||||
from cutlass_api.fusion.ir import (
|
||||
# Load Node
|
||||
AccumulatorImpl,
|
||||
AuxLoadImpl,
|
||||
ColumnBroadcastImpl,
|
||||
LoadNode,
|
||||
RowBroadcastImpl,
|
||||
ScalarBroadcastImpl,
|
||||
# Compute Node
|
||||
ComputeImpl,
|
||||
# Store Node
|
||||
AuxStoreImpl,
|
||||
ColumnReductionImpl,
|
||||
RowReductionImpl,
|
||||
ScalarReductionImpl,
|
||||
)
|
||||
|
||||
|
||||
class Sm80AccumulatorImpl(AccumulatorImpl):
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::threadblock::VisitorAccFetch;\n"""
|
||||
return self._type_decl
|
||||
|
||||
|
||||
class Sm80AuxLoadImpl(AuxLoadImpl):
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxLoad<
|
||||
OutputTileThreadMap, {DataTypeTag[self.element]}, {self.stride_mnl}
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
|
||||
class Sm80LoadSrcImpl(Sm80AuxLoadImpl):
|
||||
pass
|
||||
|
||||
|
||||
class Sm80ScalarBroadcastImpl(ScalarBroadcastImpl):
|
||||
def __init__(self, node: LoadNode) -> None:
|
||||
super().__init__(node)
|
||||
self.broadcast_count = 1
|
||||
self.reduction_fn = FunctionalOp.Multiplies
|
||||
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarBroadcast<
|
||||
{DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
|
||||
class Sm80RowBroadcastImpl(RowBroadcastImpl):
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowBroadcast<
|
||||
OutputTileThreadMap, {DataTypeTag[self.element]},
|
||||
{self.stride_mnl}
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
|
||||
class Sm80ColumnBroadcastImpl(ColumnBroadcastImpl):
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColBroadcast<
|
||||
OutputTileThreadMap, {DataTypeTag[self.element]},
|
||||
{self.stride_mnl}
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
|
||||
class Sm80ComputeImpl(ComputeImpl):
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
{op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
|
||||
{FloatRoundStyleTag[self.round_style]}
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
|
||||
class Sm80AuxStoreImpl(AuxStoreImpl):
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxStore<
|
||||
OutputTileThreadMap, {DataTypeTag[self.element]}, {FloatRoundStyleTag[self.round_style]},
|
||||
{self.stride_mnl}
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
|
||||
class Sm80StoreDImpl(Sm80AuxStoreImpl):
|
||||
pass
|
||||
|
||||
|
||||
class Sm80ColumnReductionImpl(ColumnReductionImpl):
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColReduction<
|
||||
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
|
||||
OutputTileThreadMap, {DataTypeTag[self.element]},
|
||||
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
|
||||
{self.stride_mnl}
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
|
||||
class Sm80RowReductionImpl(RowReductionImpl):
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowReduction<
|
||||
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
|
||||
OutputTileThreadMap, {DataTypeTag[self.element]},
|
||||
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
|
||||
{self.stride_mnl}
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
|
||||
class Sm80ScalarReductionImpl(ScalarReductionImpl):
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarReduction<
|
||||
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
|
||||
OutputTileThreadMap, {DataTypeTag[self.element]},
|
||||
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
|
||||
{self.stride_mnl}
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
@@ -0,0 +1,97 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Emitter for Sm90 Epilogue Visitor
|
||||
"""
|
||||
|
||||
from cutlass_api.fusion.library import DataTypeTag, EpilogueScheduleTag
|
||||
from cutlass_api.fusion.backend.emitter_base import FusionCallbacks
|
||||
|
||||
|
||||
class CollectiveEpilogue:
|
||||
def __init__(
|
||||
self, tile_description, schedule, element_c, element_d, fusion_callbacks
|
||||
) -> None:
|
||||
self.cta_tile_mnk = tile_description.threadblock_shape
|
||||
self.element_c = element_c
|
||||
self.element_d = element_d
|
||||
self.schedule = schedule
|
||||
self.fusion_callbacks = fusion_callbacks
|
||||
|
||||
@property
|
||||
def CtaTileMNK(self) -> str:
|
||||
"""
|
||||
The threadblock shape
|
||||
"""
|
||||
return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>"
|
||||
|
||||
@property
|
||||
def EpilogueTileType(self) -> str:
|
||||
"""
|
||||
The epilogue tile type
|
||||
"""
|
||||
return "cutlass::epilogue::collective::EpilogueTileAuto"
|
||||
|
||||
@property
|
||||
def Schedule(self) -> str:
|
||||
return EpilogueScheduleTag[self.schedule]
|
||||
|
||||
def emit(self):
|
||||
callback_decl, callback_name = self.fusion_callbacks.emit()
|
||||
return (
|
||||
callback_name,
|
||||
f"""
|
||||
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
{self.CtaTileMNK}, {self.EpilogueTileType},
|
||||
{DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
|
||||
{self.Schedule}
|
||||
>;
|
||||
{callback_decl}
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class Sm90Emitter:
|
||||
def __init__(self, operation, graph) -> None:
|
||||
fusion_callbacks = FusionCallbacks(graph, cc=90, emit_CD=False)
|
||||
|
||||
self.collective_epilogue = CollectiveEpilogue(
|
||||
tile_description=operation.tile_description,
|
||||
schedule=operation.tile_description.epilogue_schedule,
|
||||
element_c=operation.C.element,
|
||||
element_d=operation.D.element,
|
||||
fusion_callbacks=fusion_callbacks,
|
||||
)
|
||||
|
||||
def emit(self):
|
||||
return self.collective_epilogue.emit()
|
||||
327
python/cutlass_api/cutlass_api/fusion/backend/sm90_nodes.py
Normal file
327
python/cutlass_api/cutlass_api/fusion/backend/sm90_nodes.py
Normal file
@@ -0,0 +1,327 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
from cutlass_api.fusion.library import (
|
||||
DataTypeSize,
|
||||
DataTypeTag,
|
||||
FloatRoundStyleTag,
|
||||
FunctionalOp,
|
||||
op_tag,
|
||||
)
|
||||
from cutlass_api.fusion.ir import (
|
||||
# Load Node
|
||||
AccumulatorImpl,
|
||||
AuxLoadImpl,
|
||||
ColumnBroadcastImpl,
|
||||
LoadNode,
|
||||
LoadSrcImpl,
|
||||
RowBroadcastImpl,
|
||||
ScalarBroadcastImpl,
|
||||
# Compute Node
|
||||
ComputeImpl,
|
||||
# Store Node
|
||||
AuxStoreImpl,
|
||||
ColumnReductionImpl,
|
||||
RowReductionImpl,
|
||||
ScalarReductionImpl,
|
||||
StoreDImpl,
|
||||
)
|
||||
from cutlass_api.fusion.pycute import product
|
||||
|
||||
|
||||
class Sm90AccumulatorImpl(AccumulatorImpl):
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::fusion::Sm90AccFetch;\n"""
|
||||
return self._type_decl
|
||||
|
||||
|
||||
class Sm90LoadSrcImpl(LoadSrcImpl):
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = f"""
|
||||
using ElementC = {DataTypeTag[self.element]};
|
||||
using StrideC = {self.stride_mnl};
|
||||
using {self.name_camel} = cutlass::epilogue::fusion::Sm90SrcFetch<{DataTypeTag[self.element]}>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
|
||||
class Sm90AuxLoadImpl(AuxLoadImpl):
|
||||
@property
|
||||
def descriptor(self) -> str:
|
||||
"""
|
||||
Descriptor for Aux Load
|
||||
"""
|
||||
return f"{self.name_camel}Descriptor"
|
||||
|
||||
def decl_descriptor(self) -> str:
|
||||
"""
|
||||
Declare the descriptor type
|
||||
"""
|
||||
return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::AuxLoadDescriptor<EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}>;\n"
|
||||
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = self.decl_descriptor()
|
||||
self._type_decl += f"""
|
||||
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad<
|
||||
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
|
||||
{self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
def get_smem_size(
|
||||
self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles
|
||||
):
|
||||
"""
|
||||
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
|
||||
"""
|
||||
return (
|
||||
DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8,
|
||||
128,
|
||||
)
|
||||
|
||||
|
||||
class Sm90ScalarBroadcastImpl(ScalarBroadcastImpl):
|
||||
def __init__(self, node: LoadNode) -> None:
|
||||
super().__init__(node)
|
||||
self.broadcast_count = 1
|
||||
self.reduction_fn = FunctionalOp.Multiplies
|
||||
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
|
||||
{DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
|
||||
class Sm90RowBroadcastImpl(RowBroadcastImpl):
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
|
||||
{self.stride_mnl}
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
|
||||
class Sm90ColumnBroadcastImpl(ColumnBroadcastImpl):
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
|
||||
{self.stride_mnl}
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
|
||||
class Sm90ComputeImpl(ComputeImpl):
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::fusion::Sm90Compute<
|
||||
{op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
|
||||
{FloatRoundStyleTag[self.round_style]}
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
|
||||
class Sm90AuxStoreImpl(AuxStoreImpl):
|
||||
@property
|
||||
def descriptor(self) -> str:
|
||||
"""
|
||||
Descriptor for Aux Load
|
||||
"""
|
||||
return f"{self.name_camel}Descriptor"
|
||||
|
||||
def decl_descriptor(self) -> str:
|
||||
"""
|
||||
Declare the descriptor type
|
||||
"""
|
||||
return f"""
|
||||
using {self.descriptor} = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
|
||||
EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}
|
||||
>;
|
||||
"""
|
||||
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = self.decl_descriptor()
|
||||
self._type_decl += f"""
|
||||
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore<
|
||||
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
|
||||
{FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom,
|
||||
typename {self.descriptor}::CopyOpR2S
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
def get_smem_size(
|
||||
self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles
|
||||
):
|
||||
"""
|
||||
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
|
||||
"""
|
||||
return (
|
||||
DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8,
|
||||
128,
|
||||
)
|
||||
|
||||
|
||||
class Sm90StoreDImpl(StoreDImpl):
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
return f"""
|
||||
using ElementD = {DataTypeTag[self.element]};
|
||||
using StrideD = {self.stride_mnl};
|
||||
"""
|
||||
|
||||
|
||||
class Sm90ColumnReductionImpl(ColumnReductionImpl):
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColReduction<
|
||||
{op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0,
|
||||
typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
|
||||
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
|
||||
{self.stride_mnl}
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
|
||||
class Sm90RowReductionImpl(RowReductionImpl):
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowReduction<
|
||||
{op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0 /* Stages */,
|
||||
typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
|
||||
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
|
||||
{self.stride_mnl}
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
|
||||
|
||||
class Sm90ScalarReductionImpl(ScalarReductionImpl):
|
||||
@property
|
||||
def type_decl(self):
|
||||
"""
|
||||
Return the string defining the type
|
||||
"""
|
||||
if self._type_decl is not None:
|
||||
return self._type_decl
|
||||
|
||||
self._type_decl = f"""
|
||||
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarReduction<
|
||||
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
|
||||
{DataTypeTag[self.element]}, {DataTypeTag[self.element_compute]},
|
||||
{FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}
|
||||
>;
|
||||
"""
|
||||
return self._type_decl
|
||||
88
python/cutlass_api/cutlass_api/fusion/epilogue.py
Normal file
88
python/cutlass_api/cutlass_api/fusion/epilogue.py
Normal file
@@ -0,0 +1,88 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Epilogue Visitor interface for compiling, and running visitor-based epilogue.
|
||||
"""
|
||||
|
||||
import cutlass_api.fusion.backend
|
||||
from cutlass_api.fusion.passes.util import cc_map
|
||||
|
||||
|
||||
class EpilogueFunctorVisitor:
|
||||
"""
|
||||
Apply an epilogue functor described by the epilogue EVT
|
||||
|
||||
:param cc: compute capability
|
||||
:param visitor_frontend: user-provide visitor frontend
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, cc: int, visitor, element_compute) -> None:
|
||||
# Type of Emitter based on CC
|
||||
self.emit_cls = getattr(
|
||||
cutlass_api.fusion.backend, f"Sm{cc_map[cc]}Emitter"
|
||||
)
|
||||
|
||||
# Visitor Types
|
||||
self.visitor = visitor
|
||||
self.graph = visitor.dag_ir
|
||||
|
||||
# Data types
|
||||
self.element_epilogue = element_compute # element compute
|
||||
self.element_output = self.graph.get_node_meta("D").underlying_impl.element
|
||||
|
||||
# Epilogue Thread Type
|
||||
if cc_map[cc] in [90, 100]:
|
||||
self.arg_c_type = self.visitor.arg_c_type
|
||||
self.arg_d_type = self.visitor.arg_d_type
|
||||
|
||||
# Epilogue stages specialized for sm80 kernel
|
||||
if cc == 80:
|
||||
if hasattr(self.visitor, "epilogue_stages"):
|
||||
self.epilogue_stages = self.visitor.epilogue_stages
|
||||
assert self.epilogue_stages <= 2, (
|
||||
"Only supports Stages <=2 in SM80 Epilogue"
|
||||
)
|
||||
|
||||
def emit(self, operation):
|
||||
"""
|
||||
Emit the C++ code
|
||||
"""
|
||||
emitter = self.emit_cls(operation, self.graph)
|
||||
return emitter.emit()
|
||||
|
||||
def get_smem_size(self, tile_description):
|
||||
"""
|
||||
Get the shared memory size in bytes
|
||||
"""
|
||||
return self.visitor.get_smem_size(tile_description)
|
||||
120
python/cutlass_api/cutlass_api/fusion/evt_ops.py
Normal file
120
python/cutlass_api/cutlass_api/fusion/evt_ops.py
Normal file
@@ -0,0 +1,120 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Collection of builtin functions used for host reference in EVT
|
||||
"""
|
||||
|
||||
from cutlass_api.utils import (
|
||||
is_numpy_available,
|
||||
is_numpy_tensor,
|
||||
is_torch_available,
|
||||
is_torch_tensor,
|
||||
)
|
||||
|
||||
if is_numpy_available():
|
||||
import numpy as np
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
def multiply_add(x, y, z):
|
||||
return x * y + z
|
||||
|
||||
|
||||
def sum(x, dim):
|
||||
if is_numpy_tensor(x):
|
||||
return x.sum(axis=tuple(dim))
|
||||
elif is_torch_tensor(x):
|
||||
return torch.sum(x, dim)
|
||||
else:
|
||||
raise TypeError(f"Unsupported tensor type: {type(x)}")
|
||||
|
||||
|
||||
def max(x, dim):
|
||||
if is_numpy_tensor(x):
|
||||
return x.max(axis=tuple(dim))
|
||||
elif is_torch_tensor(x):
|
||||
return torch.amax(x, dim)
|
||||
else:
|
||||
raise TypeError(f"Unsupported tensor type: {type(x)}")
|
||||
|
||||
|
||||
def maximum(x, y):
|
||||
if is_numpy_tensor(x):
|
||||
return np.maximum(x, y)
|
||||
elif is_torch_tensor(x):
|
||||
return torch.maximum(x, torch.tensor(y))
|
||||
else:
|
||||
raise TypeError(f"Unsupported tensor type: {type(x)}")
|
||||
|
||||
|
||||
def minimum(x, y):
|
||||
if is_numpy_tensor(x):
|
||||
return np.minimum(x, y)
|
||||
elif is_torch_tensor(x):
|
||||
return torch.minimum(x, torch.tensor(y))
|
||||
else:
|
||||
raise TypeError(f"Unsupported tensor type: {type(x)}")
|
||||
|
||||
|
||||
def exp(x):
|
||||
if is_numpy_tensor(x):
|
||||
return np.exp(x)
|
||||
elif is_torch_tensor(x):
|
||||
return torch.exp(x)
|
||||
else:
|
||||
raise TypeError(f"Unsupported tensor type: {type(x)}")
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Layout manipulation nodes
|
||||
##############################################################################
|
||||
|
||||
|
||||
def permute(x, indices: tuple):
|
||||
if is_numpy_tensor(x):
|
||||
return np.transpose(x, axes=indices)
|
||||
elif is_torch_tensor(x):
|
||||
return x.permute(*indices)
|
||||
else:
|
||||
raise TypeError(f"Unsupported tensor type: {type(x)}")
|
||||
|
||||
|
||||
def reshape(x, new_shape: tuple):
|
||||
if is_numpy_tensor(x):
|
||||
return np.reshape(x, newshape=new_shape)
|
||||
elif is_torch_tensor(x):
|
||||
return x.view(new_shape)
|
||||
else:
|
||||
raise TypeError(f"Unsupported tensor type: {type(x)}")
|
||||
41
python/cutlass_api/cutlass_api/fusion/frontend/__init__.py
Normal file
41
python/cutlass_api/cutlass_api/fusion/frontend/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass_api.fusion.frontend.python_ast import (
|
||||
PythonASTFrontend,
|
||||
PythonASTInOutProcessor,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PythonASTFrontend",
|
||||
"PythonASTInOutProcessor",
|
||||
]
|
||||
300
python/cutlass_api/cutlass_api/fusion/frontend/frontend_base.py
Normal file
300
python/cutlass_api/cutlass_api/fusion/frontend/frontend_base.py
Normal file
@@ -0,0 +1,300 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Base class for Python EVT Frontend
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
|
||||
from cutlass_api.fusion.library import DataType
|
||||
from cutlass_api.fusion.ir import (
|
||||
ComputeNode,
|
||||
DAGIR,
|
||||
LayoutNode,
|
||||
LoadNode,
|
||||
StoreNode,
|
||||
)
|
||||
from cutlass_api.fusion.passes import (
|
||||
EVTGraphDrawer,
|
||||
EVTPassManager,
|
||||
GetSmemSize,
|
||||
PassDAG2Tree,
|
||||
PassGetArgumentType,
|
||||
PassGetImpl,
|
||||
PassFixElementD,
|
||||
PassLayoutManipulateElimination,
|
||||
PassPreprocessRed,
|
||||
PassShapeTypePropagation,
|
||||
)
|
||||
from cutlass_api.fusion.passes.util import cc_map
|
||||
from cutlass_api.fusion.evt_ops import permute, reshape
|
||||
|
||||
|
||||
class EVTFrontendBase:
|
||||
layout_fns = {"permute": permute, "reshape": reshape}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cc,
|
||||
element_compute=DataType.f32,
|
||||
additional_passes=None,
|
||||
requires_conversion_to_tree=True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.cc = cc
|
||||
self.element_compute = element_compute
|
||||
self.dag_ir = DAGIR(self.cc, self.element_compute)
|
||||
self.compute_cnt = 0
|
||||
self.layout_cnt = 0
|
||||
self.imm_cnt = 0
|
||||
|
||||
if additional_passes is None:
|
||||
additional_passes = []
|
||||
|
||||
passes = [
|
||||
PassPreprocessRed,
|
||||
PassGetArgumentType,
|
||||
PassShapeTypePropagation,
|
||||
PassLayoutManipulateElimination,
|
||||
PassGetImpl,
|
||||
PassDAG2Tree,
|
||||
PassFixElementD,
|
||||
] + additional_passes
|
||||
|
||||
soft_dependencies = []
|
||||
if not requires_conversion_to_tree:
|
||||
passes.remove(PassDAG2Tree)
|
||||
soft_dependencies.append(PassDAG2Tree)
|
||||
|
||||
self.pass_manager = EVTPassManager(
|
||||
self.dag_ir, passes, soft_dependencies=soft_dependencies
|
||||
)
|
||||
|
||||
if self.cc == 80:
|
||||
self._epilogue_stages = 1
|
||||
else:
|
||||
self._epilogue_stages = None
|
||||
|
||||
@property
|
||||
def epilogue_stages(self):
|
||||
return self._epilogue_stages
|
||||
|
||||
@epilogue_stages.setter
|
||||
def epilogue_stages(self, stages):
|
||||
self._epilogue_stages = stages
|
||||
|
||||
def parse(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"The 'parse' function must be overloaded in frontend class"
|
||||
)
|
||||
|
||||
def trace(self, *args, **kwargs):
|
||||
# Parse the input
|
||||
self.parse(*args, **kwargs)
|
||||
|
||||
# Verify the DAG IR to ensure that "D" is the output node with out_degree = 0
|
||||
if self.cc >= 90:
|
||||
if self.dag_ir.out_degree("D") != 0:
|
||||
raise RuntimeError(
|
||||
f"On SM90 or higher, D is expected to be a output node with 0 users to "
|
||||
f"enable smem reuse between C and D, but got {self.dag_ir.out_degree('D')}"
|
||||
)
|
||||
|
||||
# Run the passes
|
||||
self.pass_manager()
|
||||
# Set the epilogue type
|
||||
self.epilogue_thread_type = self.dag_ir.epilogue_thread_type
|
||||
if cc_map[self.cc] in [90, 100]:
|
||||
self.arg_c_type = self.dag_ir.arg_c_type
|
||||
self.arg_d_type = self.dag_ir.arg_d_type
|
||||
self.reduction_names = self.dag_ir.reduction_names
|
||||
|
||||
#
|
||||
# Helper functions for DAG IR manipulation
|
||||
#
|
||||
|
||||
def add_node(self, node):
|
||||
self.dag_ir.add_node(node)
|
||||
|
||||
def add_edge(self, src, tgt, weight=0):
|
||||
self.dag_ir.add_edge(src, tgt, weight=weight)
|
||||
|
||||
def set_tensor(self, node_name, example):
|
||||
"""
|
||||
Add an example tensor to node {node_name} in the DAG IR
|
||||
"""
|
||||
meta = self.dag_ir.get_node_meta(node_name)
|
||||
meta.tensor = {"tensor": example}
|
||||
|
||||
def set_store_tensor(self, node_name, example):
|
||||
"""
|
||||
Add an example tensor to node {node_name} in the DAG IR
|
||||
"""
|
||||
meta = self.dag_ir.get_node_meta(node_name)
|
||||
meta.store_tensor = {"tensor": example}
|
||||
|
||||
def mark_output(self, node_name):
|
||||
"""
|
||||
Mark a store node as output
|
||||
"""
|
||||
meta = self.dag_ir.get_node_meta(node_name)
|
||||
# Allow accum to also be used as an output so that
|
||||
# a user can include it in the return line
|
||||
if not isinstance(meta, StoreNode) and not meta.name == "accum":
|
||||
raise ValueError(
|
||||
f"Only StoreNodes can be marked as output. "
|
||||
f"Got {type(meta).__name__}: {node_name}"
|
||||
)
|
||||
meta.is_output = True
|
||||
|
||||
# Add node with specific type
|
||||
|
||||
def add_load_node(self, name, example):
|
||||
"""
|
||||
Add a Load node to DAG IR
|
||||
:param name: name of the loaded variable
|
||||
:type name: str
|
||||
:param example: example input
|
||||
:type example: np.ndarray|torch.Tensor|cupy.ndarray|float
|
||||
"""
|
||||
if name is None:
|
||||
raise ValueError("Name is not provided.")
|
||||
if example is None:
|
||||
raise ValueError(f"Example input for {name} is not provided.")
|
||||
load_node = LoadNode(name)
|
||||
load_node.tensor = {"tensor": example}
|
||||
# Special logics for accumulator
|
||||
if name == "accum":
|
||||
if load_node.tensor.rank == 2:
|
||||
new_shape = tuple(
|
||||
[
|
||||
1,
|
||||
]
|
||||
+ list(load_node.tensor.shape)
|
||||
)
|
||||
load_node.tensor.broadcast(new_shape)
|
||||
elif load_node.tensor.rank < 2 or load_node.tensor.rank > 3:
|
||||
raise ValueError(
|
||||
f"Expect example inputs for 'accum' be a rank-2 or rank-3 tensor. Got {load_node.tensor.shape}."
|
||||
)
|
||||
self.add_node(load_node)
|
||||
|
||||
def add_imm(self, value: Union[float, int]):
|
||||
"""
|
||||
Add an immediate scalar value to DAG IR
|
||||
:param value: the value of the immediate scalar
|
||||
:type value: float
|
||||
"""
|
||||
try:
|
||||
value = float(value)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"{type(value).__name__} cannot be converted to float."
|
||||
) from e
|
||||
|
||||
name = f"imm_{value}_k{self.imm_cnt}".replace(".", "_")
|
||||
self.imm_cnt += 1
|
||||
load_node = LoadNode(name)
|
||||
load_node.tensor = {"tensor": value, "is_constant": True}
|
||||
self.add_node(load_node)
|
||||
return name
|
||||
|
||||
def add_compute_node(self, op, name=None):
|
||||
"""
|
||||
Add a compute node.
|
||||
:param op: the computation op
|
||||
:param name: the node name (optional)
|
||||
:type name: str
|
||||
:return: the name of the compute node
|
||||
"""
|
||||
if name is None:
|
||||
name = f"compute_{self.compute_cnt}"
|
||||
self.compute_cnt += 1
|
||||
compute_node = ComputeNode(
|
||||
name=name,
|
||||
fn=op,
|
||||
element_output=self.element_compute,
|
||||
element_compute=self.element_compute,
|
||||
)
|
||||
self.add_node(compute_node)
|
||||
return compute_node.name
|
||||
|
||||
def add_layout_node(self, op, kwargs, name=None):
|
||||
"""
|
||||
Add a layout node.
|
||||
:param op: the layout op
|
||||
:type op: evt_ops
|
||||
:param name: the node name (optional)
|
||||
:type name: str
|
||||
:return: the name of the layout node
|
||||
"""
|
||||
if name is None:
|
||||
name = f"layout_{self.layout_cnt}"
|
||||
self.layout_cnt += 1
|
||||
layout_node = LayoutNode(name=name, fn=op, kwargs=kwargs)
|
||||
self.add_node(layout_node)
|
||||
return layout_node.name
|
||||
|
||||
def add_store_node(self, name):
|
||||
store_node = StoreNode(name)
|
||||
self.add_node(store_node)
|
||||
|
||||
#
|
||||
# Visualization The DAG IR
|
||||
#
|
||||
|
||||
def visualize(self, name="dag_ir"):
|
||||
"""
|
||||
Visualize the dag ir with svg file
|
||||
:param name: the name of the graph
|
||||
"""
|
||||
drawer = EVTGraphDrawer(self.dag_ir, name)
|
||||
try:
|
||||
for name, graph in drawer.get_dot_graph():
|
||||
graph.write_svg(f"./{name}.svg")
|
||||
except:
|
||||
raise RuntimeError(
|
||||
"'dot' is not found in path. GraphDrawer is disabled. "
|
||||
"Please install it with 'sudo apt-get install graphviz'."
|
||||
)
|
||||
|
||||
#
|
||||
# Get shared memory size
|
||||
#
|
||||
|
||||
def get_smem_size(self, tile_description):
|
||||
"""
|
||||
Get the shared memory size of the epilogue
|
||||
"""
|
||||
smem_size = GetSmemSize(self.dag_ir)(tile_description)
|
||||
return smem_size
|
||||
262
python/cutlass_api/cutlass_api/fusion/frontend/python_ast.py
Normal file
262
python/cutlass_api/cutlass_api/fusion/frontend/python_ast.py
Normal file
@@ -0,0 +1,262 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Python AST frontend that parses input into DAG IR
|
||||
"""
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
from typing import Callable, Union
|
||||
|
||||
|
||||
from cutlass_api.fusion.activation import (
|
||||
identity,
|
||||
relu,
|
||||
tanh,
|
||||
sigmoid,
|
||||
silu,
|
||||
hardswish,
|
||||
gelu,
|
||||
)
|
||||
from cutlass_api.fusion.frontend.frontend_base import EVTFrontendBase
|
||||
from cutlass_api.fusion.library import DataType, FunctionalOp
|
||||
|
||||
|
||||
class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
|
||||
def __init__(self, cc, element_compute=DataType.f32, **kwargs):
|
||||
super().__init__(cc, element_compute, **kwargs)
|
||||
# Flags
|
||||
# If this state is True, visit_Constant returns values without creating imm node
|
||||
self.no_imm = False
|
||||
self.visiting_return = False
|
||||
|
||||
def parse(self, example_inputs):
|
||||
self.example_inputs = example_inputs
|
||||
self.source = textwrap.dedent(inspect.getsource(self.__call__))
|
||||
self.ast = ast.parse(self.source)
|
||||
self.visit(self.ast)
|
||||
|
||||
#
|
||||
# Helper functions
|
||||
#
|
||||
@staticmethod
|
||||
def ast_op_to_bindings(op):
|
||||
mapping = {
|
||||
ast.Add: FunctionalOp.Plus,
|
||||
ast.Sub: FunctionalOp.Minus,
|
||||
ast.Mult: FunctionalOp.Multiplies,
|
||||
ast.Div: FunctionalOp.Divides,
|
||||
"maximum": FunctionalOp.Maximum,
|
||||
"minimum": FunctionalOp.Minimum,
|
||||
"identity": identity.binding_type,
|
||||
"relu": relu.binding_type,
|
||||
"tanh": tanh.binding_type,
|
||||
"sigmoid": sigmoid.binding_type,
|
||||
"silu": silu.binding_type,
|
||||
"hardswish": hardswish.binding_type,
|
||||
"gelu": gelu.binding_type,
|
||||
"multiply_add": FunctionalOp.MultiplyAdd,
|
||||
"sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd),
|
||||
"max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum),
|
||||
"exp": FunctionalOp.Exp,
|
||||
}
|
||||
return mapping[op]
|
||||
|
||||
#
|
||||
# Visiting different node types
|
||||
#
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef):
|
||||
# Visit args and register load nodes
|
||||
for arg in node.args.args:
|
||||
self.visit(arg)
|
||||
for expr in node.body:
|
||||
self.visit(expr)
|
||||
|
||||
def visit_arg(self, node: ast.arg):
|
||||
# Name of the argument
|
||||
name = node.arg
|
||||
try:
|
||||
example_tensor = self.example_inputs[name]
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Example input for {name} is not provided.") from e
|
||||
|
||||
self.add_load_node(name, example_tensor)
|
||||
|
||||
def visit_Name(self, node: ast.Name):
|
||||
return node.id
|
||||
|
||||
def visit_Constant(self, node: ast.Constant):
|
||||
if self.no_imm:
|
||||
return node.value
|
||||
else:
|
||||
name = self.add_imm(node.value)
|
||||
return name
|
||||
|
||||
def visit_Tuple(self, node: ast.Tuple):
|
||||
results = []
|
||||
for elt in node.elts:
|
||||
results.append(self.visit(elt))
|
||||
return tuple(results)
|
||||
|
||||
def visit_keyword(self, node: ast.keyword):
|
||||
return {node.arg: self.visit(node.value)}
|
||||
|
||||
def visit_BinOp(self, node: ast.BinOp):
|
||||
if self.visiting_return:
|
||||
raise SyntaxError("Return value cannot be an expression")
|
||||
lhs = self.visit(node.left)
|
||||
rhs = self.visit(node.right)
|
||||
op = self.ast_op_to_bindings(type(node.op))
|
||||
name = self.add_compute_node(op)
|
||||
|
||||
# Add edges
|
||||
# The edge weights are used to sort the input args
|
||||
self.add_edge(lhs, name, weight=0)
|
||||
self.add_edge(rhs, name, weight=1)
|
||||
return name
|
||||
|
||||
def visit_Assign(self, node: ast.Assign):
|
||||
target = self.visit(node.targets[0])
|
||||
value = self.visit(node.value)
|
||||
# Create the assign node
|
||||
self.add_store_node(target)
|
||||
|
||||
# Add edges
|
||||
self.add_edge(value, target)
|
||||
return target
|
||||
|
||||
def visit_Call(self, node: ast.Call):
|
||||
if self.visiting_return:
|
||||
raise SyntaxError("Return value cannot be an expression")
|
||||
func = self.visit(node.func)
|
||||
args = [self.visit(arg) for arg in node.args]
|
||||
|
||||
if func in self.layout_fns.keys():
|
||||
# Parse kwargs
|
||||
# By default, visiting imm automatically creates a load node
|
||||
# However, in function call, keyword args are used to set
|
||||
# specific function attributes such as indices for permute
|
||||
# So no_imm is set to True temporarily
|
||||
self.no_imm = True
|
||||
kwargs = {}
|
||||
for kw in node.keywords:
|
||||
kwargs.update(self.visit(kw))
|
||||
self.no_imm = False
|
||||
op = self.layout_fns[func]
|
||||
name = self.add_layout_node(op, kwargs)
|
||||
else:
|
||||
op = self.ast_op_to_bindings(func)
|
||||
name = self.add_compute_node(op)
|
||||
|
||||
# Add edges
|
||||
for idx, arg in enumerate(args):
|
||||
self.add_edge(arg, name, weight=idx)
|
||||
return name
|
||||
|
||||
def visit_Return(self, node: ast.Return):
|
||||
self.visiting_return = True
|
||||
results = self.visit(node.value)
|
||||
self.visiting_return = False
|
||||
self.return_names = results
|
||||
if not isinstance(results, tuple):
|
||||
results = (results,)
|
||||
for rst in results:
|
||||
try:
|
||||
example_tensor = self.example_inputs[rst]
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Example input for {rst} is not provided.") from e
|
||||
self.set_store_tensor(rst, example_tensor)
|
||||
self.mark_output(rst)
|
||||
|
||||
|
||||
class PythonASTInOutProcessor(ast.NodeVisitor):
|
||||
"""
|
||||
Simple processor that traces the input AST and returns the names of inputs
|
||||
and outputs
|
||||
|
||||
:param source: The source code of the function to trace.
|
||||
:type source: Union[Callable, str]
|
||||
:return: A tuple of lists of input and output names.
|
||||
:rtype: tuple[list[str], list[str]]
|
||||
"""
|
||||
|
||||
def trace(self, source: Union[Callable, str]) -> tuple[list[str], list[str]]:
|
||||
if callable(source):
|
||||
self.source = textwrap.dedent(inspect.getsource(source))
|
||||
if isinstance(source, str):
|
||||
self.source = source
|
||||
self.ast = ast.parse(self.source)
|
||||
|
||||
self.inputs = []
|
||||
self.outputs = []
|
||||
self.visiting_return = False
|
||||
|
||||
self.visit(self.ast)
|
||||
if "accum" not in self.inputs:
|
||||
raise ValueError("accum must be an input to the epilogue function")
|
||||
|
||||
# Users do not need to specify accum. Remove it.
|
||||
self.inputs.remove("accum")
|
||||
return self.inputs, self.outputs
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef):
|
||||
for arg in node.args.args:
|
||||
self.visit(arg)
|
||||
for expr in node.body:
|
||||
self.visit(expr)
|
||||
|
||||
def visit_arg(self, node: ast.arg):
|
||||
self.inputs.append(node.arg)
|
||||
|
||||
def visit_Constant(self, node: ast.Constant):
|
||||
if self.visiting_return:
|
||||
raise SyntaxError("Return value cannot be a constant")
|
||||
|
||||
def visit_Name(self, node: ast.Name):
|
||||
return node.id
|
||||
|
||||
def visit_Tuple(self, node: ast.Tuple):
|
||||
results = []
|
||||
for elt in node.elts:
|
||||
results.append(self.visit(elt))
|
||||
return tuple(results)
|
||||
|
||||
def visit_Return(self, node: ast.Return):
|
||||
self.visiting_return = True
|
||||
results = self.visit(node.value)
|
||||
self.visiting_return = False
|
||||
if not isinstance(results, tuple):
|
||||
results = (results,)
|
||||
self.outputs.extend(results)
|
||||
75
python/cutlass_api/cutlass_api/fusion/ir/__init__.py
Normal file
75
python/cutlass_api/cutlass_api/fusion/ir/__init__.py
Normal file
@@ -0,0 +1,75 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass_api.fusion.ir.compute_nodes import ComputeNode, ComputeImpl
|
||||
from cutlass_api.fusion.ir.dag_ir import DAGIR
|
||||
from cutlass_api.fusion.ir.layout_nodes import LayoutNode
|
||||
from cutlass_api.fusion.ir.load_nodes import (
|
||||
LoadNode,
|
||||
AccumulatorImpl,
|
||||
LoadSrcImpl,
|
||||
AuxLoadImpl,
|
||||
RowBroadcastImpl,
|
||||
ColumnBroadcastImpl,
|
||||
ScalarBroadcastImpl,
|
||||
)
|
||||
from cutlass_api.fusion.ir.node import TopoVisitorNode, NoOpImpl
|
||||
from cutlass_api.fusion.ir.store_nodes import (
|
||||
StoreNode,
|
||||
StoreDImpl,
|
||||
AuxStoreImpl,
|
||||
ColumnReductionImpl,
|
||||
RowReductionImpl,
|
||||
ScalarReductionImpl,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ComputeNode",
|
||||
"ComputeImpl",
|
||||
"DAGIR",
|
||||
"LayoutNode",
|
||||
"LoadNode",
|
||||
"AccumulatorImpl",
|
||||
"LoadSrcImpl",
|
||||
"AuxLoadImpl",
|
||||
"RowBroadcastImpl",
|
||||
"ColumnBroadcastImpl",
|
||||
"ScalarBroadcastImpl",
|
||||
"TopoVisitorNode",
|
||||
"NoOpImpl",
|
||||
"StoreNode",
|
||||
"StoreDImpl",
|
||||
"AuxStoreImpl",
|
||||
"ColumnReductionImpl",
|
||||
"RowReductionImpl",
|
||||
"ScalarReductionImpl",
|
||||
]
|
||||
246
python/cutlass_api/cutlass_api/fusion/ir/c_types.py
Normal file
246
python/cutlass_api/cutlass_api/fusion/ir/c_types.py
Normal file
@@ -0,0 +1,246 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
import ctypes
|
||||
|
||||
from cutlass_api.fusion.library import DataType
|
||||
from cutlass_api.utils import is_numpy_tensor, is_torch_tensor, is_numpy_available
|
||||
|
||||
|
||||
dtype2ctype = {
|
||||
DataType.f16: ctypes.c_uint16,
|
||||
DataType.bf16: ctypes.c_uint16,
|
||||
DataType.f32: ctypes.c_float,
|
||||
DataType.f64: ctypes.c_double,
|
||||
DataType.s8: ctypes.c_int8,
|
||||
DataType.s32: ctypes.c_int32,
|
||||
}
|
||||
|
||||
|
||||
def get_scalar(value):
|
||||
"""
|
||||
Returns a scalar value from a container (e.g., np.ndarray)
|
||||
"""
|
||||
if is_numpy_tensor(value) or is_torch_tensor(value):
|
||||
if value.size != 1:
|
||||
raise Exception("Scalars used in epilogue must be of size 1")
|
||||
return value.reshape(-1)[0]
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
def to_ctype_value(value, dtype):
|
||||
"""
|
||||
Converts ``value`` to the corresponding storage needed for the ctype that
|
||||
will store ``value``.
|
||||
"""
|
||||
scalar = get_scalar(value)
|
||||
if dtype == DataType.f16:
|
||||
if is_numpy_available():
|
||||
import numpy as np
|
||||
|
||||
# Convert f16 value into an integer
|
||||
return int.from_bytes(np.float16(scalar).tobytes(), "little")
|
||||
else:
|
||||
raise NotImplementedError("Numpy is not available")
|
||||
else:
|
||||
return scalar
|
||||
|
||||
|
||||
|
||||
class Empty(ctypes.Structure):
|
||||
_fields_ = []
|
||||
|
||||
def __init__(self, *arg) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class EmptyByte(ctypes.Structure):
|
||||
_fields_ = [("byte", ctypes.c_byte)]
|
||||
|
||||
def __init__(self, *arg) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class EBO:
|
||||
def __init__(self, index: int, type) -> None:
|
||||
self.index = index
|
||||
self.type = type
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if isinstance(other, EBO):
|
||||
return self.index == other.index and self.type == other.type
|
||||
return False
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.index, self.type))
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"<{self.index}, {self.type}>"
|
||||
|
||||
|
||||
def tuple_factory_(input_tuple, dtype, constants: list[int] = None):
|
||||
"""
|
||||
The factory function generating cute::Tuple with input tuple
|
||||
:param input_tuple: the input tuple
|
||||
:type input_tuple: tuple
|
||||
:param dtype: the data type for non-constant values
|
||||
:type dtype: str, "int32_t", "int", "int64_t"
|
||||
:param constant: the values that will be treated as constants
|
||||
:type constant: list[int]
|
||||
|
||||
:return: ctype structure representing the cute::Tuple
|
||||
:return: the empty base classes of the tuple
|
||||
"""
|
||||
|
||||
if constants is None:
|
||||
constants = [0, 1]
|
||||
|
||||
# The empty base classes of the current tuple
|
||||
empty_bases = []
|
||||
# The first non empty base class
|
||||
first_non_empty_base = None
|
||||
# The ctype fields of the current tuple
|
||||
ctype_fields = []
|
||||
|
||||
for idx, entry in enumerate(input_tuple):
|
||||
# For nested tuples
|
||||
if isinstance(entry, tuple):
|
||||
sub_tuple_ctype, sub_empty_bases = tuple_factory_(entry, dtype, constants)
|
||||
if ctypes.sizeof(sub_tuple_ctype) == 0:
|
||||
# The empty tuple base class is also an empty EBO
|
||||
empty_bases.append(EBO(idx, entry))
|
||||
else:
|
||||
if first_non_empty_base is None:
|
||||
first_non_empty_base = sub_empty_bases
|
||||
ctype_fields.append((f"entry_{idx}", sub_tuple_ctype))
|
||||
else:
|
||||
if entry in constants:
|
||||
empty_bases.append(EBO(idx, entry))
|
||||
ctype_fields.append((f"entry_{idx}", Empty))
|
||||
else:
|
||||
ctype_fields.append((f"entry_{idx}", dtype))
|
||||
if first_non_empty_base is None:
|
||||
first_non_empty_base = []
|
||||
|
||||
# Create the ctype tuple
|
||||
class TupleType(ctypes.Structure):
|
||||
_fields_ = ctype_fields
|
||||
|
||||
def __init__(self, args) -> None:
|
||||
fields = self._fields_
|
||||
|
||||
assert len(fields) == len(args)
|
||||
for field, arg in zip(fields, args):
|
||||
name = field[0]
|
||||
field_type = field[1]
|
||||
setattr(self, name, field_type(arg))
|
||||
|
||||
return TupleType, empty_bases
|
||||
|
||||
|
||||
def tuple_factory(input_tuple, dtype: str, constants=[0, 1]):
|
||||
"""
|
||||
The factory function generating cute::Tuple with input tuple
|
||||
:param input_tuple: the input tuple
|
||||
:type input_tuple: tuple
|
||||
:param dtype: the data type for non-constant values
|
||||
:type dtype: str, "int32_t", "int", "int64_t"
|
||||
:param constant: the values that will be treated as constants
|
||||
:type constant: list[int]
|
||||
|
||||
:return: ctype structure representing the cute::Tuple
|
||||
:return: the empty base classes of the tuple
|
||||
"""
|
||||
# Step 1: convert the dtype
|
||||
if dtype == "int64_t":
|
||||
dtype = ctypes.c_longlong
|
||||
elif dtype in ["int", "int32_t"]:
|
||||
dtype = ctypes.c_int32
|
||||
else:
|
||||
raise NotImplementedError(f"Type {dtype} is not supported")
|
||||
|
||||
tuple_type, _ = tuple_factory_(input_tuple, dtype, constants)
|
||||
|
||||
if ctypes.sizeof(tuple_type) == 0:
|
||||
return EmptyByte
|
||||
return tuple_type
|
||||
|
||||
|
||||
def visitor_factory(node_types, node_names):
|
||||
"""
|
||||
Creates the argument type of epilogue visitor type
|
||||
|
||||
:param node_types: list of argument types under ctypes
|
||||
:param node_names: list of argument names under str
|
||||
|
||||
:return: tuple type in ctypes.Structure
|
||||
"""
|
||||
ctypes_field = []
|
||||
# Struct is used when number of nodes < 4
|
||||
# Because the Sm90VisitorImplBase has specification up to 4 nodes
|
||||
# in `include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp`
|
||||
if len(node_types) <= 4:
|
||||
for idx, node_type in enumerate(node_types):
|
||||
if ctypes.sizeof(node_type) == 0:
|
||||
# Special case for empty struct
|
||||
# 1 byte placeholder is used for correct alignment
|
||||
ctypes_field.append((node_names[idx], ctypes.c_byte))
|
||||
else:
|
||||
ctypes_field.append((node_names[idx], node_type))
|
||||
|
||||
class VisitorType(ctypes.Structure):
|
||||
_fields_ = ctypes_field
|
||||
|
||||
def __init__(self, kwargs) -> None:
|
||||
for field in self._fields_:
|
||||
fname, ftype = field
|
||||
if ftype != ctypes.c_byte:
|
||||
setattr(self, fname, ftype(kwargs))
|
||||
|
||||
# For cases with more than 4 nodes, tuple is used
|
||||
else:
|
||||
for idx, node_type in enumerate(node_types):
|
||||
ctypes_field.append((node_names[idx], node_type))
|
||||
|
||||
class VisitorType(ctypes.Structure):
|
||||
_fields_ = ctypes_field
|
||||
|
||||
def __init__(self, kwargs) -> None:
|
||||
for field in self._fields_:
|
||||
fname, ftype = field
|
||||
setattr(self, fname, ftype(kwargs))
|
||||
|
||||
return VisitorType
|
||||
99
python/cutlass_api/cutlass_api/fusion/ir/compute_nodes.py
Normal file
99
python/cutlass_api/cutlass_api/fusion/ir/compute_nodes.py
Normal file
@@ -0,0 +1,99 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Python registration for compute nodes in EVT
|
||||
"""
|
||||
|
||||
from cutlass_api.fusion.ir.node import NodeBase, ImplBase
|
||||
from cutlass_api.fusion.library import FloatRoundStyle
|
||||
|
||||
|
||||
class ComputeImplBase(ImplBase):
|
||||
"""
|
||||
Base class for compute implementation
|
||||
"""
|
||||
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
|
||||
|
||||
class ComputeImpl(ComputeImplBase):
|
||||
"""
|
||||
Implementation for Compute Node
|
||||
"""
|
||||
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
|
||||
self.fn = node.fn
|
||||
self.element_output = node.element_output
|
||||
self.element_compute = node.element_compute
|
||||
self.round_style = node.round_style
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
return True
|
||||
|
||||
|
||||
class ComputeNode(NodeBase):
|
||||
"""
|
||||
Compute Node in DAG IR
|
||||
"""
|
||||
|
||||
possible_impls = [
|
||||
ComputeImpl,
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
fn,
|
||||
element_output,
|
||||
element_compute,
|
||||
round_style=FloatRoundStyle.ToNearest,
|
||||
) -> None:
|
||||
super().__init__(name)
|
||||
self.op = "compute"
|
||||
self.fn = fn
|
||||
self.element_compute = element_compute
|
||||
self.round_style = round_style
|
||||
|
||||
def type_propagation(self, *args, **kwargs):
|
||||
"""
|
||||
Load node loads tensor under type `tensor.element` and returns an array of type `tensor.element`.
|
||||
"""
|
||||
self.element = self.element_compute
|
||||
# In general, the compute nodes have element_output = element_compute
|
||||
# In certain cases like producer of D it is overwritten by other passes
|
||||
if not hasattr(self, "element_output"):
|
||||
self.element_output = self.element
|
||||
258
python/cutlass_api/cutlass_api/fusion/ir/dag_ir.py
Normal file
258
python/cutlass_api/cutlass_api/fusion/ir/dag_ir.py
Normal file
@@ -0,0 +1,258 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
DAG IR used by Python EVT
|
||||
"""
|
||||
|
||||
import networkx as nx
|
||||
|
||||
|
||||
from cutlass_api.fusion.ir.compute_nodes import ComputeNode
|
||||
from cutlass_api.fusion.library import DataType
|
||||
from cutlass_api.fusion.ir.node import NodeBase
|
||||
from cutlass_api.fusion.library import ActivationOp
|
||||
|
||||
|
||||
class DAGIR:
|
||||
"""
|
||||
``DAGIR`` is the main data structure used in the EVT Intermediate Representation.
|
||||
It consists of a series of ``Node`` s, each representing epilogue visitor nodes.
|
||||
|
||||
In the DAGIR, ``node`` is an string of its name. ``node_meta`` is the underlying class of the node
|
||||
"""
|
||||
|
||||
def __init__(self, cc, element_compute=DataType.f32) -> None:
|
||||
# The EVT DAGIR is managed through the nextworkX Digraph class
|
||||
self._graph = nx.DiGraph()
|
||||
|
||||
self.element_compute = element_compute
|
||||
|
||||
self.reduction_names = []
|
||||
|
||||
self.cc = cc
|
||||
|
||||
self.identity_counter = 0
|
||||
|
||||
#
|
||||
# IR manipulator
|
||||
#
|
||||
|
||||
def add_node(self, meta: NodeBase):
|
||||
"""
|
||||
Add a node to dag ir
|
||||
"""
|
||||
if self.has_node(meta.name):
|
||||
raise SyntaxError(f"Variable '{meta.name}' cannot be defined twice.")
|
||||
self._graph.add_node(meta.name, meta=meta)
|
||||
|
||||
def add_edge(self, src: str, dst: str, weight: int = 0):
|
||||
"""
|
||||
Add an edge src -> dst to dag ir with weight
|
||||
"""
|
||||
if not self.has_node(src):
|
||||
raise SyntaxError(f"Variable '{src}' is undefined.")
|
||||
if not self.has_node(dst):
|
||||
raise SyntaxError(f"Variable '{dst}' is undefined.")
|
||||
|
||||
if self._graph.has_edge(src, dst):
|
||||
# The DiGraph doesn't support multiple edges between two nodes
|
||||
# We insert an identity node in such case as a workaround
|
||||
identity_name = f"autogen_identity_{self.identity_counter}"
|
||||
self.identity_counter += 1
|
||||
compute_node = ComputeNode(
|
||||
name=identity_name,
|
||||
fn=ActivationOp.Identity,
|
||||
element_output=self.element_compute,
|
||||
element_compute=self.element_compute,
|
||||
)
|
||||
self.add_node(compute_node)
|
||||
self.add_edge(src, identity_name, 0)
|
||||
self.add_edge(identity_name, dst, weight)
|
||||
else:
|
||||
self._graph.add_edge(src, dst, weight=weight)
|
||||
|
||||
def remove_node(self, node: str):
|
||||
"""
|
||||
Remove node from dag ir
|
||||
"""
|
||||
self._graph.remove_node(node)
|
||||
|
||||
def remove_edge(self, src: str, dst: str):
|
||||
"""
|
||||
Remove edge src -> dst
|
||||
"""
|
||||
self._graph.remove_edge(src, dst)
|
||||
|
||||
#
|
||||
# Helper functions for getting attrs
|
||||
#
|
||||
|
||||
def has_node(self, node: str) -> bool:
|
||||
"""
|
||||
Check if the node is in the graph
|
||||
"""
|
||||
return self._graph.has_node(node)
|
||||
|
||||
def in_degree(self, node: str):
|
||||
"""
|
||||
Get the input degree of node
|
||||
"""
|
||||
return self._graph.in_degree(node)
|
||||
|
||||
def in_edges(self, node: str):
|
||||
"""
|
||||
Get the input edges of node
|
||||
"""
|
||||
return [edge for edge in self._graph.in_edges(node)]
|
||||
|
||||
def out_degree(self, node: str):
|
||||
"""
|
||||
Get the output degree of node
|
||||
"""
|
||||
return self._graph.out_degree(node)
|
||||
|
||||
def out_edges(self, node: str):
|
||||
"""
|
||||
Get the output edges of node
|
||||
"""
|
||||
return [edge for edge in self._graph.out_edges(node)]
|
||||
|
||||
def get_node_meta(self, node: str):
|
||||
"""
|
||||
Get the meta data of the node
|
||||
"""
|
||||
return self._graph.nodes[node]["meta"]
|
||||
|
||||
def get_edge_weight(self, src, dst):
|
||||
"""
|
||||
Get the edge weight of edge src->dst
|
||||
"""
|
||||
return self._graph.get_edge_data(src, dst)["weight"]
|
||||
|
||||
#
|
||||
# High-level helper functions
|
||||
#
|
||||
|
||||
def all_reachable_nodes(self, node: str):
|
||||
"""
|
||||
Get all the nodes reachable from the current node (exclude)
|
||||
"""
|
||||
return list(nx.dfs_preorder_nodes(self._graph, source=node))
|
||||
|
||||
def get_users(self, node: str):
|
||||
"""
|
||||
Get all users of the current node
|
||||
"""
|
||||
return [edge[1] for edge in self.out_edges(node)]
|
||||
|
||||
def get_all_inputs(self, node: str):
|
||||
"""
|
||||
Get all the input nodes sorted by edge weight
|
||||
"""
|
||||
in_edges = self.in_edges(node)
|
||||
edge_weights = [self.get_edge_weight(*edge) for edge in in_edges]
|
||||
return [edge[0] for _, edge in sorted(zip(edge_weights, in_edges))]
|
||||
|
||||
def get_all_inputs_meta(self, node: str):
|
||||
"""
|
||||
Get all the input node metas sorted by edge weight
|
||||
"""
|
||||
return [
|
||||
self.get_node_meta(input_node) for input_node in self.get_all_inputs(node)
|
||||
]
|
||||
|
||||
def replace_all_uses_with(self, node1, node2):
|
||||
"""
|
||||
Replace all uses of node1 with node2
|
||||
"""
|
||||
for edge in self.out_edges(node1):
|
||||
weight = self.get_edge_weight(*edge)
|
||||
user = edge[1]
|
||||
self.add_edge(node2, user, weight)
|
||||
self.remove_edge(node1, user)
|
||||
self.remove_node(node1)
|
||||
|
||||
#
|
||||
# Node accessor
|
||||
#
|
||||
def nodes_topological_order(self):
|
||||
"""
|
||||
Get the nodes in the unique lexicographical topological order
|
||||
It generates a unique ordering of nodes by first sorting topologically
|
||||
and then additionally by sorting lexicographically.
|
||||
|
||||
Although topological_sort alone also works, this generates a unique key
|
||||
for each epilogue visitor pattern and ensures the compilation cache can be reused.
|
||||
:return: list[str]
|
||||
"""
|
||||
return list(nx.lexicographical_topological_sort(self._graph))
|
||||
|
||||
def node_metas_topological_order(self):
|
||||
"""
|
||||
Get the node metas in topological order
|
||||
:return: list[NodeBase]
|
||||
"""
|
||||
return [self.get_node_meta(node) for node in self.nodes_topological_order()]
|
||||
|
||||
@property
|
||||
def nodes(self):
|
||||
"""
|
||||
Get all nodes
|
||||
:return: list[str]
|
||||
"""
|
||||
return list(self._graph.nodes)
|
||||
|
||||
@property
|
||||
def nodes_meta(self):
|
||||
"""
|
||||
Get all node metas
|
||||
:return: list[NodeBase]
|
||||
"""
|
||||
return [data[1]["meta"] for data in self._graph.nodes.data()]
|
||||
|
||||
@property
|
||||
def edges(self):
|
||||
"""
|
||||
Get all edges
|
||||
:return: list[(str, str)]
|
||||
"""
|
||||
return list(self._graph.edges)
|
||||
|
||||
#
|
||||
# Path
|
||||
#
|
||||
def has_path(self, src: str, target: str) -> bool:
|
||||
"""
|
||||
Return True is a path exists from src to target
|
||||
"""
|
||||
return nx.has_path(self._graph, src, target)
|
||||
362
python/cutlass_api/cutlass_api/fusion/ir/layout_algorithm.py
Normal file
362
python/cutlass_api/cutlass_api/fusion/ir/layout_algorithm.py
Normal file
@@ -0,0 +1,362 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Layout algebras
|
||||
"""
|
||||
|
||||
from cutlass_api.fusion.pycute import (
|
||||
Layout,
|
||||
composition,
|
||||
make_layout,
|
||||
flatten,
|
||||
product,
|
||||
)
|
||||
|
||||
|
||||
def _infer_split(old_shape, new_shape):
|
||||
old_shape = _tuple_to_list(old_shape)
|
||||
new_shape = _tuple_to_list(new_shape)
|
||||
if len(old_shape) == 0 and len(new_shape) == 0:
|
||||
return []
|
||||
if len(old_shape) == 0:
|
||||
if product(tuple(new_shape)) != 1:
|
||||
raise ValueError("Invalid reshape size")
|
||||
else:
|
||||
return new_shape
|
||||
if len(new_shape) == 0:
|
||||
if product(tuple(old_shape)) != 1:
|
||||
raise ValueError("Invalid reshape size")
|
||||
else:
|
||||
return old_shape
|
||||
# This is done recursively by only process the last dimension at each time
|
||||
old_dim = old_shape[-1]
|
||||
new_dim = new_shape[-1]
|
||||
# Exact match
|
||||
if old_dim == new_dim:
|
||||
return _infer_split(old_shape[:-1], new_shape[:-1]) + [
|
||||
new_dim,
|
||||
]
|
||||
# Needs split
|
||||
if old_dim > new_dim and old_dim % new_dim == 0:
|
||||
residual = old_dim // new_dim
|
||||
return _infer_split(
|
||||
old_shape[:-1] + [residual],
|
||||
new_shape[:-1],
|
||||
) + [new_dim]
|
||||
# Needs merge
|
||||
if old_dim < new_dim and new_dim % old_dim == 0:
|
||||
residual = new_dim // old_dim
|
||||
return _infer_split(
|
||||
old_shape[:-1],
|
||||
new_shape[:-1] + [residual],
|
||||
) + [old_dim]
|
||||
|
||||
raise NotImplementedError(f"Unsupported split: {old_shape} -> {new_shape}")
|
||||
|
||||
|
||||
def _infer_merge(flatten_shape, shape):
|
||||
flatten_shape = _tuple_to_list(flatten_shape)
|
||||
shape = _tuple_to_list(shape)
|
||||
idx_flat = 0
|
||||
merged_shape = []
|
||||
for dim in shape:
|
||||
# Exact match
|
||||
if dim == flatten_shape[idx_flat]:
|
||||
merged_shape.append(dim)
|
||||
idx_flat += 1
|
||||
# Need group
|
||||
elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
|
||||
residual = dim
|
||||
group = []
|
||||
while residual > 1:
|
||||
group.append(flatten_shape[idx_flat])
|
||||
residual = residual // flatten_shape[idx_flat]
|
||||
idx_flat += 1
|
||||
merged_shape.append(group)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
|
||||
|
||||
return merged_shape
|
||||
|
||||
|
||||
def _list_to_tuple(nested_list):
|
||||
if isinstance(nested_list, list) or isinstance(nested_list, tuple):
|
||||
return tuple(_list_to_tuple(item) for item in nested_list)
|
||||
return nested_list
|
||||
|
||||
|
||||
def _tuple_to_list(nested_tuple):
|
||||
if isinstance(nested_tuple, list) or isinstance(nested_tuple, tuple):
|
||||
return list(_tuple_to_list(item) for item in nested_tuple)
|
||||
return nested_tuple
|
||||
|
||||
|
||||
def _reverse_tuple(nested_tuple: tuple):
|
||||
if isinstance(nested_tuple, tuple):
|
||||
return tuple([_reverse_tuple(item) for item in nested_tuple][::-1])
|
||||
return nested_tuple
|
||||
|
||||
|
||||
def _get_first_lhs_nonzero_stride(stride_list, idx):
|
||||
for i in reversed(range(idx)):
|
||||
if stride_list[i] != 0:
|
||||
return i
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def _get_first_rhs_nonzero_stride(stride_list, idx):
|
||||
for i in range(idx + 1, len(stride_list)):
|
||||
if stride_list[i] != 0:
|
||||
return i
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def reshape(layout, new_shape):
|
||||
"""
|
||||
General reshape of input layout.
|
||||
It takes two steps:
|
||||
1. split the dimensions of the old layout
|
||||
2. merge the splitted dimensions according to the new shape
|
||||
"""
|
||||
#
|
||||
# Step 1: Split the dimensions of the old layout
|
||||
#
|
||||
# 1.1 Flat old and new shape
|
||||
old_flatten_shape = list(flatten(layout.shape))
|
||||
new_flatten_shape = list(flatten(new_shape))
|
||||
|
||||
# 1.2 Infer the flatten splitted shape
|
||||
splitted_flatten_shape = _infer_split(old_flatten_shape, new_flatten_shape)
|
||||
|
||||
# 1.3 Unflat the splitted shape based on the old shape
|
||||
splited_shape = _infer_merge(splitted_flatten_shape, old_flatten_shape)
|
||||
|
||||
# 1.4 Infer the type of each split
|
||||
# If the split type is in row-major (R), the dimension list is reversed because
|
||||
# the cute::composition only support column-major split
|
||||
split_type = [] # the type of each split (ColumnMajor or RowMajor)
|
||||
permuted_splitted_shape = []
|
||||
old_flatten_stride = list(flatten(layout.stride))
|
||||
for idx, dim in enumerate(splited_shape):
|
||||
if not isinstance(dim, list):
|
||||
permuted_splitted_shape.append(dim)
|
||||
split_type.append("C")
|
||||
else:
|
||||
lhs_stride = _get_first_lhs_nonzero_stride(old_flatten_stride, idx)
|
||||
rhs_stride = _get_first_rhs_nonzero_stride(old_flatten_stride, idx)
|
||||
# Special case for single tuple
|
||||
# Use column-major by default
|
||||
if lhs_stride is None and rhs_stride is None:
|
||||
permuted_splitted_shape.append(dim)
|
||||
split_type.append("C")
|
||||
else:
|
||||
if lhs_stride is not None and rhs_stride is not None:
|
||||
# We consider shape[idx]:stride[idx]
|
||||
# Case 1: stride[idx - 1] <= stride[idx] <= stride[idx + 1]: column major
|
||||
if (
|
||||
lhs_stride <= old_flatten_stride[idx]
|
||||
and old_flatten_stride[idx] <= rhs_stride
|
||||
):
|
||||
permuted_splitted_shape.append(dim)
|
||||
split_type.append("C")
|
||||
# Case 2: stride[idx - 1] > stride[idx] > stride[idx + 1]: row major
|
||||
elif (
|
||||
lhs_stride > old_flatten_stride[idx]
|
||||
and old_flatten_stride[idx] > rhs_stride
|
||||
):
|
||||
permuted_splitted_shape.append([d for d in reversed(dim)])
|
||||
split_type.append("R")
|
||||
# Case 3: stride[idx - 1] <= stride[idx] > stride[idx + 1]: concave
|
||||
elif (
|
||||
lhs_stride <= old_flatten_stride[idx]
|
||||
and old_flatten_stride[idx] > rhs_stride
|
||||
):
|
||||
if lhs_stride >= rhs_stride:
|
||||
permuted_splitted_shape.append(dim)
|
||||
split_type.append("C")
|
||||
else:
|
||||
permuted_splitted_shape.append([d for d in reversed(dim)])
|
||||
split_type.append("R")
|
||||
# Case 4: stride[idx - 1] > stride[idx] <= stride[idx + 1]: concave
|
||||
elif (
|
||||
lhs_stride > old_flatten_stride[idx]
|
||||
and old_flatten_stride[idx] <= rhs_stride
|
||||
):
|
||||
if lhs_stride >= rhs_stride:
|
||||
permuted_splitted_shape.append(dim)
|
||||
split_type.append("C")
|
||||
else:
|
||||
permuted_splitted_shape.append([d for d in reversed(dim)])
|
||||
split_type.append("R")
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
elif lhs_stride is None:
|
||||
# Case 1: dim's stride < dim+1's stride, expand in column major
|
||||
if old_flatten_stride[idx] > rhs_stride:
|
||||
permuted_splitted_shape.append([d for d in reversed(dim)])
|
||||
split_type.append("R")
|
||||
else:
|
||||
permuted_splitted_shape.append(dim)
|
||||
split_type.append("C")
|
||||
else:
|
||||
# Case 1: dim's stride > dim-1's stride
|
||||
if old_flatten_stride[idx] < lhs_stride:
|
||||
permuted_splitted_shape.append([d for d in reversed(dim)])
|
||||
split_type.append("R")
|
||||
else:
|
||||
permuted_splitted_shape.append(dim)
|
||||
split_type.append("C")
|
||||
|
||||
# 1.4 Generate the splitted layout
|
||||
permuted_splitted_layout = composition(
|
||||
layout, Layout(_list_to_tuple(permuted_splitted_shape))
|
||||
)
|
||||
|
||||
# 1.5 Reverse the permutation in 1.4 before merge
|
||||
splitted_shape = []
|
||||
splitted_stride = []
|
||||
for shape_dim, stride_dim, type in zip(
|
||||
permuted_splitted_layout.shape, permuted_splitted_layout.stride, split_type
|
||||
):
|
||||
if type == "C":
|
||||
splitted_shape.append(shape_dim)
|
||||
splitted_stride.append(stride_dim)
|
||||
else:
|
||||
splitted_shape.append(tuple([d for d in reversed(shape_dim)]))
|
||||
splitted_stride.append(tuple([d for d in reversed(stride_dim)]))
|
||||
splitted_layout = Layout(tuple(splitted_shape), tuple(splitted_stride))
|
||||
|
||||
#
|
||||
# Step 2: Merge the splitted dimensions according to the new shape
|
||||
#
|
||||
# 2.1 Merge layout
|
||||
merged_layout = composition(splitted_layout, Layout(new_shape))
|
||||
|
||||
# 2.2 Cleaning up
|
||||
output_layout = composition(merged_layout, Layout(new_shape))
|
||||
return output_layout
|
||||
|
||||
|
||||
def permutation(layout, permutation):
|
||||
"""
|
||||
Permute the layout
|
||||
"""
|
||||
new_shape = tuple([layout.shape[idx] for idx in permutation])
|
||||
new_stride = tuple([layout.stride[idx] for idx in permutation])
|
||||
return Layout(new_shape, new_stride)
|
||||
|
||||
|
||||
def _broadcast(layout, new_shape):
|
||||
if len(layout) == 1 and isinstance(new_shape, int):
|
||||
old_dim = layout.shape
|
||||
old_stride = layout.stride
|
||||
new_dim = new_shape
|
||||
if old_dim == new_dim:
|
||||
return Layout(old_dim, old_stride)
|
||||
elif old_dim == 1:
|
||||
return Layout(new_dim, 0)
|
||||
else:
|
||||
raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {new_dim}")
|
||||
|
||||
# Align the dimensions
|
||||
old_shape = layout.shape
|
||||
if isinstance(old_shape, int):
|
||||
old_shape = (old_shape,)
|
||||
sub_layouts = [
|
||||
layout,
|
||||
]
|
||||
else:
|
||||
sub_layouts = [sub_layout for sub_layout in layout]
|
||||
rhs_broadcast_layouts = [Layout(1, 0)] * (len(new_shape) - len(old_shape))
|
||||
# Get the broadcasted layout
|
||||
broadcast_layouts = []
|
||||
try:
|
||||
layout = make_layout(*sub_layouts, *rhs_broadcast_layouts)
|
||||
broadcast_layouts = []
|
||||
for idx, sub_layout in enumerate(layout):
|
||||
broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
|
||||
except NotImplementedError:
|
||||
layout = make_layout(*rhs_broadcast_layouts, *sub_layouts)
|
||||
for idx, sub_layout in enumerate(layout):
|
||||
broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
|
||||
return make_layout(*broadcast_layouts)
|
||||
|
||||
|
||||
def broadcast(layout, new_shape):
|
||||
"""
|
||||
Broadcast the new layout based on the input shape
|
||||
The broadcasted shape equals to the new shape
|
||||
The stride of broadcasted dimensions are 0
|
||||
"""
|
||||
return _broadcast(layout, new_shape)
|
||||
|
||||
|
||||
def debroadcast(layout, dims):
|
||||
"""
|
||||
Squeeze the 0-stride
|
||||
"""
|
||||
for dim in dims:
|
||||
if layout.stride[dim] != 0:
|
||||
raise ValueError(
|
||||
f"Dim{dim} cannot be debroadcasted as it has stride {layout.stride[dim]}"
|
||||
)
|
||||
new_shape = tuple([s for idx, s in enumerate(layout.shape) if idx not in dims])
|
||||
new_stride = tuple([s for idx, s in enumerate(layout.stride) if idx not in dims])
|
||||
return Layout(new_shape, new_stride)
|
||||
|
||||
|
||||
def canonicalization_(shapes, strides):
|
||||
if isinstance(shapes, tuple):
|
||||
c_shapes = []
|
||||
c_strides = []
|
||||
for shape, stride in zip(shapes, strides):
|
||||
c_shape, c_stride = canonicalization_(shape, stride)
|
||||
c_shapes.append(c_shape)
|
||||
c_strides.append(c_stride)
|
||||
return tuple(c_shapes), tuple(c_strides)
|
||||
else:
|
||||
if shapes == 1:
|
||||
return 1, 0
|
||||
else:
|
||||
return shapes, strides
|
||||
|
||||
|
||||
def canonicalization(layout):
|
||||
"""
|
||||
Canonicalize the input layout
|
||||
1. set the stride of shape "1" to 0
|
||||
"""
|
||||
new_shape, new_stride = canonicalization_(layout.shape, layout.stride)
|
||||
return Layout(new_shape, new_stride)
|
||||
351
python/cutlass_api/cutlass_api/fusion/ir/layout_nodes.py
Normal file
351
python/cutlass_api/cutlass_api/fusion/ir/layout_nodes.py
Normal file
@@ -0,0 +1,351 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Layout manipulation nodes and implementations
|
||||
|
||||
The layout Nodes change the layout of intermediate nodes in epilogue visitor graph
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from cutlass_api.fusion.library import LayoutType
|
||||
from cutlass_api.fusion.ir.layout_algorithm import _list_to_tuple, _tuple_to_list
|
||||
from cutlass_api.fusion.ir.node import NodeBase
|
||||
from cutlass_api.fusion.ir.tensor import Tensor
|
||||
from cutlass_api.fusion.pycute import product, flatten
|
||||
|
||||
|
||||
class PermutationImpl:
|
||||
"""
|
||||
Detailed implementation and helper functions for permutation
|
||||
"""
|
||||
|
||||
def __init__(self, node) -> None:
|
||||
assert "indices" in node.kwargs.keys()
|
||||
self.indices = list(node.kwargs["indices"])
|
||||
self.inverse_indices = self.get_inverse_indices(self.indices)
|
||||
|
||||
def get_inverse_impl(self):
|
||||
inverse_impl = deepcopy(self)
|
||||
inverse_impl.indices = self.inverse_indices
|
||||
inverse_impl.inverse_indices = self.indices
|
||||
return inverse_impl
|
||||
|
||||
def update(self, shape):
|
||||
num_dim = len(shape)
|
||||
indices = self.indices
|
||||
num_old_dim = len(indices)
|
||||
# Add offset
|
||||
for i, idx in enumerate(indices):
|
||||
indices[i] = idx + num_dim - num_old_dim
|
||||
# Add broadcast dims
|
||||
for i in range(num_dim - num_old_dim):
|
||||
indices = [i] + indices
|
||||
|
||||
self.indices = indices
|
||||
self.inverse_indices = self.get_inverse_indices(self.indices)
|
||||
|
||||
def get_inverse_indices(self, indices):
|
||||
"""
|
||||
Get the indices for inverse permutation
|
||||
"""
|
||||
num_dim = len(indices)
|
||||
inverse_indices = [0] * num_dim
|
||||
for i in range(num_dim):
|
||||
inverse_indices[indices[i]] = i
|
||||
return inverse_indices
|
||||
|
||||
def shape_propagation(self, input_node_meta):
|
||||
input_shape = input_node_meta.tensor.shape
|
||||
output_shape = tuple([input_shape[idx] for idx in self.indices])
|
||||
return output_shape
|
||||
|
||||
def broadcast(self, shape, node_meta: NodeBase):
|
||||
"""
|
||||
Broadcast the inputs based on current shape
|
||||
"""
|
||||
self.update(shape)
|
||||
inverse_shape = tuple([shape[idx] for idx in self.inverse_indices])
|
||||
node_meta.tensor.broadcast(inverse_shape)
|
||||
|
||||
def apply_to_user(self, usr_meta: NodeBase):
|
||||
"""
|
||||
Propagate the permutation to the users of the current nodes
|
||||
"""
|
||||
usr_meta.tensor.permute(self.inverse_indices)
|
||||
if hasattr(usr_meta, "store_tensor"):
|
||||
if usr_meta.store_tensor is not None:
|
||||
usr_meta.store_tensor.permute(self.inverse_indices)
|
||||
|
||||
def apply_to_input(self, input_meta: NodeBase):
|
||||
"""
|
||||
Propagate the permutation to inputs of the current nodes
|
||||
"""
|
||||
input_meta.tensor.permute(self.indices)
|
||||
if hasattr(input_meta, "store_tensor"):
|
||||
if input_meta.store_tensor is not None:
|
||||
input_meta.store_tensor.permute(self.indices)
|
||||
|
||||
|
||||
class ReshapeImpl:
|
||||
"""
|
||||
Detailed implementation and helper functions for reshape
|
||||
"""
|
||||
|
||||
def __init__(self, node) -> None:
|
||||
self.node = node
|
||||
assert "new_shape" in node.kwargs.keys()
|
||||
self.output_shape = _list_to_tuple(node.kwargs["new_shape"])
|
||||
|
||||
def get_inverse_impl(self):
|
||||
inverse_impl = deepcopy(self)
|
||||
inverse_impl.output_shape = self.input_shape
|
||||
inverse_impl.input_shape = self.output_shape
|
||||
return inverse_impl
|
||||
|
||||
def shape_propagation(self, input_node_meta):
|
||||
self.input_shape = input_node_meta.tensor.shape
|
||||
return _list_to_tuple(self.output_shape)
|
||||
|
||||
def broadcast(self, shape, node_meta: NodeBase):
|
||||
"""
|
||||
Broadcast the inputs based on current shape.
|
||||
"""
|
||||
# Step 1: infer split
|
||||
flatten_split_shape = self.infer_split(
|
||||
flatten(self.input_shape), flatten(self.output_shape)
|
||||
)
|
||||
split_input_shape = self.infer_merge(flatten_split_shape, self.input_shape)
|
||||
split_output_shape = self.infer_merge(flatten_split_shape, self.output_shape)
|
||||
|
||||
# broadcast shape -> split_output_shape -> flatten_split_shape
|
||||
if len(shape) - len(split_output_shape) > 0:
|
||||
for _ in range(len(shape) - len(split_output_shape)):
|
||||
split_output_shape = [1] + split_output_shape
|
||||
flatten_split_shape = [1] + flatten_split_shape
|
||||
split_input_shape = [1] + split_input_shape
|
||||
broadcast_factor = []
|
||||
for dim, old_dim in zip(shape, split_output_shape):
|
||||
if not isinstance(dim, list):
|
||||
dim = [dim]
|
||||
if not isinstance(old_dim, list):
|
||||
old_dim = [old_dim]
|
||||
if product(tuple(dim)) == product(tuple(old_dim)):
|
||||
broadcast_factor += [1] * len(old_dim)
|
||||
elif product(tuple(old_dim)) == 1:
|
||||
assert len(dim) == 1
|
||||
broadcast_factor.append(dim[0])
|
||||
else:
|
||||
raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {dim}")
|
||||
|
||||
# flatten_split_shape -> split_input_shape
|
||||
factor_idx = 0
|
||||
broadcast_split_input_shape = []
|
||||
for dim in split_input_shape:
|
||||
if isinstance(dim, list):
|
||||
new_dim = []
|
||||
for d in dim:
|
||||
new_dim.append(d * broadcast_factor[factor_idx])
|
||||
factor_idx += 1
|
||||
broadcast_split_input_shape.append(new_dim)
|
||||
else:
|
||||
broadcast_split_input_shape.append(dim * broadcast_factor[factor_idx])
|
||||
factor_idx += 1
|
||||
broadcast_split_input_shape = _list_to_tuple(broadcast_split_input_shape)
|
||||
node_meta.tensor.reshape(_list_to_tuple(split_input_shape))
|
||||
node_meta.tensor.broadcast(broadcast_split_input_shape)
|
||||
# Last reshape op to clean up
|
||||
broadcast_input_shape = tuple(
|
||||
[product(dim) for dim in broadcast_split_input_shape]
|
||||
)
|
||||
node_meta.tensor.reshape(broadcast_input_shape)
|
||||
# Update the input shape and output shape
|
||||
self.input_shape = _list_to_tuple(node_meta.tensor.shape)
|
||||
self.output_shape = _list_to_tuple(shape)
|
||||
|
||||
def apply_to_user(self, user_meta: NodeBase):
|
||||
"""
|
||||
Propagate the reshape to user nodes
|
||||
"""
|
||||
user_meta.tensor.reshape(tuple(self.input_shape))
|
||||
if hasattr(user_meta, "store_tensor"):
|
||||
if user_meta.store_tensor is not None:
|
||||
user_meta.store_tensor.reshape(tuple(self.input_shape))
|
||||
|
||||
def apply_to_input(self, input_meta: NodeBase):
|
||||
"""
|
||||
Propagate the reshape to input nodes
|
||||
"""
|
||||
input_meta.tensor.reshape(tuple(self.output_shape))
|
||||
if hasattr(input_meta, "store_tensor"):
|
||||
if input_meta.store_tensor is not None:
|
||||
input_meta.store_tensor.reshape(tuple(self.output_shape))
|
||||
|
||||
#
|
||||
# Helper functions
|
||||
#
|
||||
|
||||
def infer_split(self, input_shape, output_shape):
|
||||
"""
|
||||
Infer the flatten splitted shape that can be merged to both input_shape and output_shape
|
||||
"""
|
||||
input_shape = _tuple_to_list(input_shape)
|
||||
output_shape = _tuple_to_list(output_shape)
|
||||
if len(input_shape) == 0 and len(output_shape) == 0:
|
||||
return []
|
||||
if len(input_shape) == 0:
|
||||
if product(tuple(output_shape)) != 1:
|
||||
raise ValueError("Invalid reshape size")
|
||||
else:
|
||||
return output_shape
|
||||
if len(output_shape) == 0:
|
||||
if product(tuple(input_shape)) != 1:
|
||||
raise ValueError("Invalid reshape size")
|
||||
else:
|
||||
return input_shape
|
||||
# This is done recursively by only process the last dimension at each time
|
||||
old_dim = input_shape[-1]
|
||||
new_dim = output_shape[-1]
|
||||
# Exact match
|
||||
if old_dim == new_dim:
|
||||
return self.infer_split(input_shape[:-1], output_shape[:-1]) + [new_dim]
|
||||
# Needs split
|
||||
if old_dim > new_dim and old_dim % new_dim == 0:
|
||||
residual = old_dim // new_dim
|
||||
return self.infer_split(
|
||||
input_shape[:-1] + [residual],
|
||||
output_shape[:-1],
|
||||
) + [new_dim]
|
||||
# Needs merge
|
||||
if old_dim < new_dim and new_dim % old_dim == 0:
|
||||
residual = new_dim // old_dim
|
||||
return self.infer_split(
|
||||
input_shape[:-1],
|
||||
output_shape[:-1] + [residual],
|
||||
) + [old_dim]
|
||||
|
||||
raise NotImplementedError(f"Unsupported split: {input_shape} -> {output_shape}")
|
||||
|
||||
def infer_merge(self, flatten_shape, shape):
|
||||
flatten_shape = _tuple_to_list(flatten_shape)
|
||||
shape = _tuple_to_list(shape)
|
||||
idx_flat = len(flatten_shape) - 1
|
||||
merged_shape = []
|
||||
for dim in reversed(shape):
|
||||
# Exact match
|
||||
if dim == flatten_shape[idx_flat]:
|
||||
merged_shape.append(dim)
|
||||
idx_flat -= 1
|
||||
# need group
|
||||
elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
|
||||
residual = dim
|
||||
group = []
|
||||
while residual > 1:
|
||||
group.append(flatten_shape[idx_flat])
|
||||
residual = residual // flatten_shape[idx_flat]
|
||||
idx_flat -= 1
|
||||
merged_shape.append(group[::-1])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported merge: {flatten_shape} -> {shape}"
|
||||
)
|
||||
|
||||
return merged_shape[::-1]
|
||||
|
||||
|
||||
class LayoutNode(NodeBase):
|
||||
"""
|
||||
Layout manipulation nodes
|
||||
"""
|
||||
|
||||
fn_to_impl = {
|
||||
"permute": PermutationImpl,
|
||||
"reshape": ReshapeImpl,
|
||||
}
|
||||
|
||||
def __init__(self, name: str, fn, kwargs: dict) -> None:
|
||||
super().__init__(name)
|
||||
self.op = "layout"
|
||||
self.fn = fn
|
||||
self.kwargs = kwargs
|
||||
self.underlying_impl = self.fn_to_impl[self.fn.__name__](self)
|
||||
|
||||
def get_inverse_node(self):
|
||||
inverse_node = deepcopy(self)
|
||||
inverse_node.underlying_impl = self.underlying_impl.get_inverse_impl()
|
||||
return inverse_node
|
||||
|
||||
def shape_propagation(self, input_node_metas):
|
||||
if self._tensor is not None:
|
||||
return
|
||||
assert len(input_node_metas) == 1, "Layout node can only have one input node"
|
||||
|
||||
output_shape = self.underlying_impl.shape_propagation(input_node_metas[0])
|
||||
|
||||
self._tensor = Tensor(
|
||||
element=self.element_output,
|
||||
shape=output_shape,
|
||||
layout_tag=LayoutType.RowMajor,
|
||||
)
|
||||
|
||||
return super().shape_propagation(input_node_metas)
|
||||
|
||||
def type_propagation(self, input_node_metas: "list[NodeBase]"):
|
||||
"""
|
||||
The store nodes has element_output = element_input
|
||||
"""
|
||||
assert len(input_node_metas) == 1, "Layout node can only have one input node"
|
||||
self.element_output = input_node_metas[0].element_output
|
||||
|
||||
def broadcast_propagation(self, input_node_metas: "list[NodeBase]"):
|
||||
"""
|
||||
Propagate the broadcast in the reversed topological order
|
||||
"""
|
||||
if self.tensor is None:
|
||||
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
|
||||
shape = self.tensor.shape
|
||||
|
||||
for child in input_node_metas:
|
||||
self.underlying_impl.broadcast(shape, child)
|
||||
|
||||
def apply_to_user(self, usr_meta: NodeBase):
|
||||
"""
|
||||
Propagate the permutation to user nodes
|
||||
"""
|
||||
self.underlying_impl.apply_to_user(usr_meta)
|
||||
|
||||
def apply_to_input(self, input_meta: NodeBase):
|
||||
"""
|
||||
Propagate the permutation to input nodes
|
||||
"""
|
||||
self.underlying_impl.apply_to_input(input_meta)
|
||||
312
python/cutlass_api/cutlass_api/fusion/ir/load_nodes.py
Normal file
312
python/cutlass_api/cutlass_api/fusion/ir/load_nodes.py
Normal file
@@ -0,0 +1,312 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Load nodes and implementations
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
|
||||
from cutlass_api.fusion.ir.c_types import dtype2ctype, to_ctype_value, tuple_factory
|
||||
from cutlass_api.fusion.ir.node import NodeBase, ImplBase
|
||||
|
||||
|
||||
class LoadImplBase(ImplBase):
|
||||
"""
|
||||
Base class for load node implementations
|
||||
"""
|
||||
|
||||
reserved_names = ["accum", "C"]
|
||||
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
self.element = node.element
|
||||
self.element_output = node.element_output
|
||||
self.stride = node.tensor.stride
|
||||
|
||||
|
||||
class AccumulatorImpl(LoadImplBase):
|
||||
"""
|
||||
Accumulator node implementation
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
return node.name == "accum" and node.tensor.shape == problem_size
|
||||
|
||||
|
||||
class LoadSrcImpl(LoadImplBase):
|
||||
"""
|
||||
Load C implementation
|
||||
"""
|
||||
|
||||
@property
|
||||
def name_camel(self) -> str:
|
||||
return "TensorC"
|
||||
|
||||
@property
|
||||
def argument_type_c(self):
|
||||
stride_mnl = self.get_stride_mnl()
|
||||
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
||||
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = [("ptr_C", ctypes.c_void_p), ("stride_C", tuple_type)]
|
||||
|
||||
def __init__(self, ptr) -> None:
|
||||
self.ptr_C = ptr
|
||||
self.stride_C = tuple_type(stride_mnl)
|
||||
|
||||
return _Argument
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
return node.name == "C" and node.tensor.shape == problem_size
|
||||
|
||||
|
||||
class AuxLoadImpl(LoadImplBase):
|
||||
"""
|
||||
Load arbitrary tensor
|
||||
"""
|
||||
|
||||
@property
|
||||
def argument_type(self):
|
||||
stride_mnl = self.get_stride_mnl()
|
||||
name = self.name
|
||||
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
||||
element_type = self.element
|
||||
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("ptr_aux", ctypes.c_void_p),
|
||||
("null_default", dtype2ctype[element_type]),
|
||||
("dAux", tuple_type),
|
||||
]
|
||||
|
||||
def __init__(self, kwargs) -> None:
|
||||
ptr = kwargs[name]
|
||||
self.ptr_aux = ptr
|
||||
self.null_default = to_ctype_value(0, element_type)
|
||||
self.dAux = tuple_type(stride_mnl)
|
||||
|
||||
return _Argument
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if node.name in LoadImplBase.reserved_names:
|
||||
return False
|
||||
strideMN = node.tensor.stride[-2:]
|
||||
if 1 not in strideMN or 0 in strideMN:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class RowBroadcastImpl(LoadImplBase):
|
||||
"""
|
||||
Broadcast a row vector
|
||||
"""
|
||||
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
self.stride_dtype = "int"
|
||||
|
||||
@property
|
||||
def argument_type(self):
|
||||
stride_mnl = self.get_stride_mnl()
|
||||
name = self.name
|
||||
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
||||
element_type = self.element
|
||||
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("ptr_row", ctypes.c_void_p),
|
||||
("null_default", dtype2ctype[element_type]),
|
||||
("dRow", tuple_type),
|
||||
]
|
||||
|
||||
def __init__(self, kwargs) -> None:
|
||||
ptr = kwargs[name]
|
||||
self.ptr_row = ptr
|
||||
self.null_default = to_ctype_value(0, element_type)
|
||||
self.dRow = tuple_type(stride_mnl)
|
||||
|
||||
return _Argument
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if node.name in LoadImplBase.reserved_names:
|
||||
return False
|
||||
|
||||
strideMN = node.tensor.stride[-2:]
|
||||
if strideMN == (0, 1):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class ColumnBroadcastImpl(LoadImplBase):
|
||||
"""
|
||||
Broadcast a column vector
|
||||
"""
|
||||
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
self.stride_dtype = "int"
|
||||
|
||||
@property
|
||||
def argument_type(self):
|
||||
stride_mnl = self.get_stride_mnl()
|
||||
name = self.name
|
||||
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
||||
element_type = self.element
|
||||
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("ptr_col", ctypes.c_void_p),
|
||||
("null_default", dtype2ctype[element_type]),
|
||||
("dCol", tuple_type),
|
||||
]
|
||||
|
||||
def __init__(self, kwargs) -> None:
|
||||
ptr = kwargs[name]
|
||||
self.ptr_col = int(ptr)
|
||||
self.null_default = to_ctype_value(0, element_type)
|
||||
self.dCol = tuple_type(stride_mnl)
|
||||
|
||||
return _Argument
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if node.name in LoadImplBase.reserved_names:
|
||||
return False
|
||||
|
||||
strideMN = node.tensor.stride[-2:]
|
||||
if strideMN == (1, 0):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class ScalarBroadcastImpl(LoadImplBase):
|
||||
"""
|
||||
Broadcast a scalar
|
||||
"""
|
||||
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
self.stride_dtype = "int"
|
||||
|
||||
@property
|
||||
def argument_type(self):
|
||||
stride_mnl = self.get_stride_mnl()
|
||||
name = self.name
|
||||
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
||||
element_type = self.element
|
||||
|
||||
if self.tensor.is_constant:
|
||||
value = self.tensor.value
|
||||
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("scalars", dtype2ctype[element_type]),
|
||||
("scalar_ptrs", ctypes.c_void_p),
|
||||
("dScalar", tuple_type),
|
||||
]
|
||||
|
||||
def __init__(self, kwargs) -> None:
|
||||
self.scalars = to_ctype_value(value, element_type)
|
||||
self.scalar_ptrs = 0
|
||||
self.dScalar = tuple_type(stride_mnl)
|
||||
|
||||
else:
|
||||
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("scalars", dtype2ctype[element_type]),
|
||||
("scalar_ptrs", ctypes.c_void_p),
|
||||
("dScalar", tuple_type),
|
||||
]
|
||||
|
||||
def __init__(self, kwargs) -> None:
|
||||
scalar_or_ptr = kwargs[name]
|
||||
if isinstance(scalar_or_ptr, float):
|
||||
self.scalars = to_ctype_value(scalar_or_ptr, element_type)
|
||||
self.scalar_ptrs = 0
|
||||
else:
|
||||
self.scalar_ptrs = int(scalar_or_ptr)
|
||||
|
||||
self.dScalar = tuple_type(stride_mnl)
|
||||
|
||||
return _Argument
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if node.name in LoadImplBase.reserved_names:
|
||||
return False
|
||||
|
||||
strideMN = node.tensor.stride[-2:]
|
||||
if strideMN == (0, 0):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class LoadNode(NodeBase):
|
||||
"""
|
||||
Load Node
|
||||
"""
|
||||
|
||||
cnt = 0
|
||||
possible_impls = [
|
||||
AccumulatorImpl,
|
||||
LoadSrcImpl,
|
||||
AuxLoadImpl,
|
||||
RowBroadcastImpl,
|
||||
ColumnBroadcastImpl,
|
||||
ScalarBroadcastImpl,
|
||||
]
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
if name is None:
|
||||
name = f"load{LoadNode.cnt}"
|
||||
LoadNode.cnt += 1
|
||||
super().__init__(name)
|
||||
self.op = "load"
|
||||
|
||||
def type_propagation(self, *args, **kwargs):
|
||||
"""
|
||||
Load node loads tensor under type `tensor.element` and returns an array of type `tensor.element`.
|
||||
"""
|
||||
if self.tensor is None:
|
||||
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
|
||||
|
||||
self.element = self.tensor.element
|
||||
self.element_output = self.tensor.element
|
||||
330
python/cutlass_api/cutlass_api/fusion/ir/node.py
Normal file
330
python/cutlass_api/cutlass_api/fusion/ir/node.py
Normal file
@@ -0,0 +1,330 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Base & visitor classes of DAGIR Nodes
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
from re import sub
|
||||
|
||||
from cutlass_api.fusion.library import LayoutType
|
||||
from cutlass_api.fusion.ir.layout_algorithm import _list_to_tuple, _reverse_tuple
|
||||
from cutlass_api.fusion.ir.tensor import Tensor
|
||||
|
||||
|
||||
class TupleEmitter:
|
||||
"""
|
||||
Emit the cute tuple to C++ code
|
||||
"""
|
||||
|
||||
def __init__(self, stride_dtype):
|
||||
self.stride_dtype = stride_dtype
|
||||
|
||||
def emit(self, py_tuple):
|
||||
if isinstance(py_tuple, int):
|
||||
if py_tuple in [0, 1]:
|
||||
return f"cute::Int<{py_tuple}>"
|
||||
else:
|
||||
return f"{self.stride_dtype}"
|
||||
elif isinstance(py_tuple, tuple):
|
||||
decl = "cute::Stride<"
|
||||
for item in py_tuple:
|
||||
decl += self.emit(item) + ", "
|
||||
return decl[:-2] + ">"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"TupleEmitter.emit only accepts tuple or int, got {type(py_tuple).__name__}"
|
||||
)
|
||||
|
||||
|
||||
class ImplBase:
|
||||
"""
|
||||
Base class for Node Implementation
|
||||
"""
|
||||
|
||||
def __init__(self, node) -> None:
|
||||
self.node = node
|
||||
self.name = node.name
|
||||
self.tensor = node.tensor
|
||||
self._type_decl = None
|
||||
self.tuple_emitter = TupleEmitter("int64_t")
|
||||
|
||||
@property
|
||||
def stride_dtype(self):
|
||||
return self.tuple_emitter.stride_dtype
|
||||
|
||||
@stride_dtype.setter
|
||||
def stride_dtype(self, stride_dtype):
|
||||
self.tuple_emitter.stride_dtype = stride_dtype
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
"""
|
||||
Match function used in get_underlying_impl
|
||||
"""
|
||||
raise NotImplementedError("The `match` function is not defined.")
|
||||
|
||||
@property
|
||||
def argument_type(self):
|
||||
"""
|
||||
Default class for Argument Type
|
||||
"""
|
||||
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = []
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
return _Argument
|
||||
|
||||
@property
|
||||
def name_camel(self) -> str:
|
||||
"""
|
||||
Return the CamelCase name.
|
||||
"""
|
||||
return sub(r"(_|-)+", " ", self.name).title().replace(" ", "")
|
||||
|
||||
@property
|
||||
def stride_mnl(self):
|
||||
"""
|
||||
Typename StrideMNL
|
||||
"""
|
||||
stride = _list_to_tuple(
|
||||
[self.stride[-2], self.stride[-1]]
|
||||
+ list(_reverse_tuple(tuple(self.stride[:-2])))
|
||||
)
|
||||
return self.tuple_emitter.emit(stride)
|
||||
|
||||
def get_non_constant_stride(self, py_tuple):
|
||||
if isinstance(py_tuple, int):
|
||||
if py_tuple not in [0, 1]:
|
||||
return py_tuple
|
||||
else:
|
||||
return None
|
||||
non_constant_stride = []
|
||||
for item in py_tuple:
|
||||
item_out = self.get_non_constant_stride(item)
|
||||
if item_out:
|
||||
non_constant_stride.append(item_out)
|
||||
return tuple(non_constant_stride)
|
||||
|
||||
def get_stride_mnl(self):
|
||||
"""
|
||||
Get the non-zero stride mnl. This is used in argument construction
|
||||
"""
|
||||
stride = _list_to_tuple(
|
||||
[self.stride[-2], self.stride[-1]]
|
||||
+ list(_reverse_tuple(tuple(self.stride[:-2])))
|
||||
)
|
||||
return stride
|
||||
|
||||
def get_smem_size(self, *args, **kwargs):
|
||||
"""
|
||||
Get the shared memory size and alignment of current node
|
||||
"""
|
||||
return (0, 1)
|
||||
|
||||
|
||||
class NoOpImpl(ImplBase):
|
||||
"""
|
||||
The NoOpImpl does nothing but forward its input to users
|
||||
"""
|
||||
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if node.op == "store":
|
||||
# Store that is not output is a No OP
|
||||
return not node.is_output
|
||||
|
||||
|
||||
class NodeBase:
|
||||
"""
|
||||
Base class of DAG Node
|
||||
"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
self.underlying_impl = None
|
||||
|
||||
self._tensor = None
|
||||
|
||||
# Whether the node is disabled for emit
|
||||
self.disabled = False
|
||||
|
||||
@property
|
||||
def name_camel(self) -> str:
|
||||
"""
|
||||
Return the CamelCase name.
|
||||
"""
|
||||
return self.underlying_impl.name_camel
|
||||
|
||||
@property
|
||||
def tensor(self) -> Tensor:
|
||||
"""
|
||||
Return the output tensor (concept: cutlass_api.fusion.ir.tensor)
|
||||
"""
|
||||
return self._tensor
|
||||
|
||||
@tensor.setter
|
||||
def tensor(self, kwargs):
|
||||
"""
|
||||
Setting the tensor
|
||||
"""
|
||||
self._tensor = Tensor(**kwargs)
|
||||
|
||||
#
|
||||
# Helper functions for type/shape propagation
|
||||
#
|
||||
|
||||
def shape_propagation(self, input_node_metas):
|
||||
"""
|
||||
Infer shape from input nodes
|
||||
General Broadcasting Rules from NumPy
|
||||
When operating on two arrays, we compare their shapes element-wise.
|
||||
It starts with the trailing (i.e. rightmost) dimension and works its
|
||||
way left. Two dimensions are compatible when
|
||||
1. they are equal
|
||||
2. one of them is 1
|
||||
"""
|
||||
if self._tensor is not None:
|
||||
return
|
||||
|
||||
shape = None
|
||||
for src in input_node_metas:
|
||||
src_shape = src.tensor.shape
|
||||
if shape is None:
|
||||
shape = src_shape
|
||||
else:
|
||||
len_difference = len(shape) - len(src_shape)
|
||||
if len_difference > 0:
|
||||
for _ in range(len_difference):
|
||||
src_shape = [1] + list(src_shape)
|
||||
elif len_difference < 0:
|
||||
for _ in range(-len_difference):
|
||||
shape = [1] + list(shape)
|
||||
broadcasted_shape = []
|
||||
# Infer broadcast shape
|
||||
for shape_dim, src_dim in zip(reversed(shape), reversed(src_shape)):
|
||||
if shape_dim == 1:
|
||||
broadcasted_shape = [src_dim] + list(broadcasted_shape)
|
||||
elif src_dim == 1:
|
||||
broadcasted_shape = [shape_dim] + list(broadcasted_shape)
|
||||
elif shape_dim == src_dim:
|
||||
broadcasted_shape = [shape_dim] + list(broadcasted_shape)
|
||||
else:
|
||||
error_msg = "Dimension mismatch between "
|
||||
for src_ in input_node_metas:
|
||||
error_msg += f"{src_.name}{src_.tensor.shape}, "
|
||||
error_msg = error_msg[:-2] + "."
|
||||
raise RuntimeError(error_msg)
|
||||
shape = tuple(broadcasted_shape)
|
||||
|
||||
self._tensor = Tensor(
|
||||
element=self.element_output, shape=shape, layout_tag=LayoutType.RowMajor
|
||||
)
|
||||
|
||||
def type_propagation(self, *args, **kwargs):
|
||||
"""
|
||||
Each node is associated with two data types: `element` and `element_output`.
|
||||
The `element_output` is the type of return array of the node. The `element`
|
||||
has specific meaning for different node types.
|
||||
* Load Node: data type of tensor in gmem
|
||||
* Compute Node: element compute
|
||||
* Store Node: data type of tensor in gmem
|
||||
This function must be overloaded in the derived classes
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"Function `type_propagation` is not overloaded in {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
def broadcast_propagation(self, input_node_metas: "list[NodeBase]"):
|
||||
"""
|
||||
Propagate the broadcast in the reversed topological order.
|
||||
For example:
|
||||
C[L, M, N] = A[M, 1] + B[L, M, N]
|
||||
After the broadcast propagation, it will be come
|
||||
C[L, M, N] = A[L, M, N] + B[L, M, N]
|
||||
and each tensor will have a proper stride accessing the underlying tensor
|
||||
"""
|
||||
if self.tensor is None:
|
||||
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
|
||||
for child in input_node_metas:
|
||||
child.tensor.broadcast(self.tensor.shape)
|
||||
|
||||
def get_underlying_impl(self, problem_size: tuple):
|
||||
"""
|
||||
Get the underlying implementation of the current node.
|
||||
"""
|
||||
if self.tensor is None:
|
||||
raise RuntimeError(
|
||||
f"The Layout of node {self.name} is unknown. Please call PassShapeTypePropagation first."
|
||||
)
|
||||
|
||||
for impl in self.possible_impls:
|
||||
if impl.match(self, problem_size):
|
||||
self.underlying_impl = impl(self)
|
||||
break
|
||||
|
||||
if self.underlying_impl is None:
|
||||
raise NotImplementedError(
|
||||
f"No matching op for node {self.name} with stride {self.tensor.stride}."
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# Visitor Nodes & Impls
|
||||
#
|
||||
|
||||
|
||||
class TopoVisitorImpl(ImplBase):
|
||||
"""
|
||||
Impl for topological visitor
|
||||
"""
|
||||
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node.output_node)
|
||||
self.name = node.name
|
||||
self.element_output = node.output_node.element_output
|
||||
|
||||
|
||||
class TopoVisitorNode(NodeBase):
|
||||
def __init__(self, name: str, subgraph, output_node) -> None:
|
||||
super().__init__(name)
|
||||
self.subgraph = subgraph
|
||||
self.output_node = output_node
|
||||
self.op = "dag"
|
||||
self.underlying_impl = TopoVisitorImpl(self)
|
||||
276
python/cutlass_api/cutlass_api/fusion/ir/store_nodes.py
Normal file
276
python/cutlass_api/cutlass_api/fusion/ir/store_nodes.py
Normal file
@@ -0,0 +1,276 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Store node and implementations
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
|
||||
from cutlass_api.fusion.library import DataType, FloatRoundStyle, FunctionalOp
|
||||
from cutlass_api.fusion.ir.c_types import dtype2ctype, to_ctype_value, tuple_factory
|
||||
from cutlass_api.fusion.ir.node import NodeBase, ImplBase, NoOpImpl
|
||||
from cutlass_api.fusion.ir.tensor import Tensor
|
||||
|
||||
|
||||
class StoreImplBase(ImplBase):
|
||||
"""
|
||||
Base class for store node implementation
|
||||
"""
|
||||
|
||||
reserved_names = ["D"]
|
||||
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
self.element = node.element
|
||||
self.element_output = node.element_output
|
||||
self.stride = node.store_tensor.stride
|
||||
|
||||
|
||||
class StoreDImpl(StoreImplBase):
|
||||
"""
|
||||
Store D implementation
|
||||
"""
|
||||
|
||||
@property
|
||||
def argument_type_d(self):
|
||||
stride_mnl = self.get_stride_mnl()
|
||||
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
||||
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = [("ptr_D", ctypes.c_void_p), ("stride_D", tuple_type)]
|
||||
|
||||
def __init__(self, ptr: int) -> None:
|
||||
self.ptr_D = ptr
|
||||
self.stride_D = tuple_type(stride_mnl)
|
||||
|
||||
return _Argument
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if node.name == "D" and node.store_tensor.shape == problem_size:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class AuxStoreImpl(StoreImplBase):
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
self.round_style = FloatRoundStyle.ToNearest
|
||||
|
||||
@property
|
||||
def argument_type(self):
|
||||
stride_mnl = self.get_stride_mnl()
|
||||
name = self.name
|
||||
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
||||
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = [("ptr_aux", ctypes.c_void_p), ("dAux", tuple_type)]
|
||||
|
||||
def __init__(self, kwargs) -> None:
|
||||
ptr = kwargs[name]
|
||||
self.ptr_aux = ptr
|
||||
self.dAux = tuple_type(stride_mnl)
|
||||
|
||||
return _Argument
|
||||
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if not node.is_output:
|
||||
return False
|
||||
if node.name in StoreImplBase.reserved_names:
|
||||
return False
|
||||
|
||||
strideMN = node.store_tensor.stride[-2:]
|
||||
if 1 not in strideMN or 0 in strideMN:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class ReductionImplBase(StoreImplBase):
|
||||
def __init__(self, node) -> None:
|
||||
super().__init__(node)
|
||||
self.element = node.store_tensor.element
|
||||
self.element_compute = node.element_compute
|
||||
self.reg_reduce_fn = self.node.reg_reduce_fn
|
||||
self.gmem_reduce_fn = self.node.gmem_reduce_fn
|
||||
self.round_style = node.round_style
|
||||
self.stride_dtype = "int"
|
||||
|
||||
def get_reduce_identity(self):
|
||||
"""
|
||||
Return the reduction identity of the current reduce_fn
|
||||
"""
|
||||
maxes = {
|
||||
DataType.f32: (2**31) - 1,
|
||||
DataType.f16: (2**15),
|
||||
DataType.s32: (2**31) - 1,
|
||||
DataType.s8: (2**7) - 1,
|
||||
}
|
||||
mins = {
|
||||
DataType.f32: -maxes[DataType.f32],
|
||||
DataType.f16: -maxes[DataType.f16],
|
||||
DataType.s32: -maxes[DataType.s32],
|
||||
DataType.s8: -maxes[DataType.s8],
|
||||
}
|
||||
if self.reg_reduce_fn == FunctionalOp.Maximum:
|
||||
if self.element_compute not in mins:
|
||||
raise Exception(f"No min entry for data type {self.element_compute}")
|
||||
return to_ctype_value(mins[self.element_compute], self.element_compute)
|
||||
elif self.reg_reduce_fn == FunctionalOp.Multiplies:
|
||||
return to_ctype_value(1.0, self.element_compute)
|
||||
elif self.reg_reduce_fn == FunctionalOp.Minimum:
|
||||
if self.element_compute not in maxes:
|
||||
raise Exception(f"No max entry for data type {self.element_compute}")
|
||||
return to_ctype_value(maxes[self.element_compute], self.element_compute)
|
||||
else:
|
||||
return to_ctype_value(0.0, self.element_compute)
|
||||
|
||||
@property
|
||||
def argument_type(self):
|
||||
self.get_reduce_identity()
|
||||
stride_mnl = self.get_stride_mnl()
|
||||
name = self.name
|
||||
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
||||
element_compute = self.element_compute
|
||||
reduce_identity = self.get_reduce_identity()
|
||||
|
||||
class _Argument(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("ptr", ctypes.c_void_p),
|
||||
("reduce_identity", dtype2ctype[element_compute]),
|
||||
("dMNL", tuple_type),
|
||||
]
|
||||
|
||||
def __init__(self, kwargs) -> None:
|
||||
ptr = kwargs[name]
|
||||
self.ptr = ptr
|
||||
self.reduce_identity = reduce_identity
|
||||
self.dMNL = tuple_type(stride_mnl)
|
||||
|
||||
return _Argument
|
||||
|
||||
|
||||
class ColumnReductionImpl(ReductionImplBase):
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if not node.is_output:
|
||||
return False
|
||||
if node.name in StoreImplBase.reserved_names:
|
||||
return False
|
||||
|
||||
strideMN = node.store_tensor.stride[-2:]
|
||||
if strideMN == (1, 0):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class RowReductionImpl(ReductionImplBase):
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if not node.is_output:
|
||||
return False
|
||||
if node.name in StoreImplBase.reserved_names:
|
||||
return False
|
||||
|
||||
strideMN = node.store_tensor.stride[-2:]
|
||||
if strideMN == (0, 1):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class ScalarReductionImpl(ReductionImplBase):
|
||||
@staticmethod
|
||||
def match(node, problem_size: tuple):
|
||||
if not node.is_output:
|
||||
return False
|
||||
if node.name in StoreImplBase.reserved_names:
|
||||
return False
|
||||
|
||||
strideMN = node.store_tensor.stride[-2:]
|
||||
if strideMN == (0, 0):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class StoreNode(NodeBase):
|
||||
"""
|
||||
Store node
|
||||
"""
|
||||
|
||||
possible_impls = [
|
||||
AuxStoreImpl,
|
||||
RowReductionImpl,
|
||||
ColumnReductionImpl,
|
||||
ScalarReductionImpl,
|
||||
NoOpImpl,
|
||||
StoreDImpl,
|
||||
]
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
super().__init__(name)
|
||||
self.op = "store"
|
||||
self.is_output = False
|
||||
self._store_tensor = None
|
||||
|
||||
@property
|
||||
def store_tensor(self) -> Tensor:
|
||||
"""
|
||||
Return the output tensor (concept: cutlass_api.fusion.ir.tensor)
|
||||
"""
|
||||
return self._store_tensor
|
||||
|
||||
@store_tensor.setter
|
||||
def store_tensor(self, kwargs):
|
||||
"""
|
||||
Setting the tensor
|
||||
"""
|
||||
self._store_tensor = Tensor(**kwargs)
|
||||
|
||||
def type_propagation(self, input_node_metas: "list[NodeBase]"):
|
||||
"""
|
||||
The store nodes has element_output = element_input
|
||||
"""
|
||||
if self.is_output:
|
||||
if self.store_tensor is None:
|
||||
raise RuntimeError(f"The store tensor of node {self.name} is unknown.")
|
||||
self.element = self.store_tensor.element
|
||||
assert len(input_node_metas) == 1, "Store node can only have one input node"
|
||||
self.element_output = input_node_metas[0].element_output
|
||||
|
||||
def broadcast_propagation(self, input_node_metas: "list[NodeBase]"):
|
||||
super().broadcast_propagation(input_node_metas)
|
||||
if self.is_output:
|
||||
self._store_tensor.broadcast(self.tensor.shape)
|
||||
155
python/cutlass_api/cutlass_api/fusion/ir/tensor.py
Normal file
155
python/cutlass_api/cutlass_api/fusion/ir/tensor.py
Normal file
@@ -0,0 +1,155 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
High-level class for tensor
|
||||
"""
|
||||
|
||||
from cutlass_api.fusion.library import (
|
||||
LayoutType,
|
||||
get_datatype_and_layout,
|
||||
get_tensor_shape,
|
||||
library_type,
|
||||
)
|
||||
from cutlass_api.fusion.ir.layout_algorithm import (
|
||||
Layout,
|
||||
broadcast,
|
||||
canonicalization,
|
||||
permutation,
|
||||
reshape,
|
||||
_reverse_tuple,
|
||||
)
|
||||
|
||||
|
||||
class Tensor:
|
||||
"""
|
||||
The tensor abstracts the data type
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tensor=None,
|
||||
element=None,
|
||||
shape=None,
|
||||
stride=None,
|
||||
layout_tag=None,
|
||||
is_constant=False,
|
||||
) -> None:
|
||||
if element is not None and tensor is not None:
|
||||
raise Exception("Must not specify both element and tensor")
|
||||
elif shape is not None and tensor is not None:
|
||||
raise Exception("Must not specify both shape and tensor")
|
||||
elif layout_tag is not None and tensor is not None:
|
||||
raise Exception("Must not specify both layout_tag and tensor")
|
||||
elif (
|
||||
element is None or (layout_tag is None and stride is None) or shape is None
|
||||
) and (tensor is None):
|
||||
raise Exception(
|
||||
"Must specify one of (element, shape, layout/stride) or (tensor)"
|
||||
)
|
||||
elif stride is not None and tensor is not None:
|
||||
raise Exception("Must not specify both stride and tensor")
|
||||
elif stride is not None and layout_tag is not None:
|
||||
raise Exception("Must not specify layout_tag when stride is provided")
|
||||
|
||||
if isinstance(tensor, Tensor):
|
||||
# Directly copy all the attributes
|
||||
self.__dict__.update(vars(tensor))
|
||||
else:
|
||||
if tensor is None:
|
||||
self.element = library_type(element)
|
||||
else:
|
||||
self.element, layout_tag = get_datatype_and_layout(tensor)
|
||||
shape = get_tensor_shape(tensor)
|
||||
if stride is not None:
|
||||
self.layout = Layout(shape[::-1], stride[::-1])
|
||||
else:
|
||||
if layout_tag == LayoutType.RowMajor:
|
||||
self.layout = Layout(shape[::-1])
|
||||
elif layout_tag == LayoutType.ColumnMajor:
|
||||
self.layout = permutation(
|
||||
Layout(shape), [idx for idx in reversed(range(len(shape)))]
|
||||
)
|
||||
self.layout = canonicalization(self.layout)
|
||||
|
||||
self.is_constant = is_constant
|
||||
# Save the tensor value if it is constant
|
||||
if is_constant and tensor is not None:
|
||||
self.value = tensor
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
"""
|
||||
Returns the RowMajor layout shape
|
||||
"""
|
||||
return _reverse_tuple(self.layout.shape)
|
||||
|
||||
@property
|
||||
def stride(self):
|
||||
"""
|
||||
Returns the RowMajor layout stride
|
||||
"""
|
||||
return _reverse_tuple(self.layout.stride)
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
"""
|
||||
Returns the rank of the tensor
|
||||
"""
|
||||
return len(self.shape)
|
||||
|
||||
#
|
||||
# Layout Algorithms
|
||||
#
|
||||
|
||||
def broadcast(self, shape):
|
||||
"""
|
||||
Broadcast self.layout to shape
|
||||
"""
|
||||
assert isinstance(shape, tuple)
|
||||
self.layout = broadcast(self.layout, _reverse_tuple(shape))
|
||||
|
||||
def reshape(self, shape):
|
||||
"""
|
||||
Reshape self.layout to shape
|
||||
"""
|
||||
assert isinstance(shape, tuple)
|
||||
reverse_shape = _reverse_tuple(shape)
|
||||
self.layout = reshape(self.layout, reverse_shape)
|
||||
|
||||
def permute(self, indices):
|
||||
"""
|
||||
Permute self.layout according to indices
|
||||
"""
|
||||
length = len(indices)
|
||||
indices = [length - idx - 1 for idx in indices]
|
||||
self.layout = permutation(self.layout, indices[::-1])
|
||||
441
python/cutlass_api/cutlass_api/fusion/library.py
Normal file
441
python/cutlass_api/cutlass_api/fusion/library.py
Normal file
@@ -0,0 +1,441 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Copies of enum types from cutlass_library that are used in the fusion frontend.
|
||||
"""
|
||||
|
||||
import enum
|
||||
from enum import auto as enum_auto
|
||||
|
||||
import cutlass
|
||||
|
||||
from cutlass_api.utils import (
|
||||
is_numpy_available,
|
||||
is_numpy_tensor,
|
||||
is_torch_available,
|
||||
is_torch_tensor,
|
||||
)
|
||||
|
||||
|
||||
class DataType(enum.Enum):
|
||||
void = enum_auto() # primarily used to disable C tensor for epilogues
|
||||
b1 = enum_auto()
|
||||
u2 = enum_auto()
|
||||
u4 = enum_auto()
|
||||
u8 = enum_auto()
|
||||
u16 = enum_auto()
|
||||
u32 = enum_auto()
|
||||
u64 = enum_auto()
|
||||
s2 = enum_auto()
|
||||
s4 = enum_auto()
|
||||
s8 = enum_auto()
|
||||
s16 = enum_auto()
|
||||
s32 = enum_auto()
|
||||
s64 = enum_auto()
|
||||
e4m3 = enum_auto()
|
||||
e5m2 = enum_auto()
|
||||
f8 = enum_auto()
|
||||
f6 = enum_auto()
|
||||
f4 = enum_auto()
|
||||
e3m2 = enum_auto()
|
||||
e2m3 = enum_auto()
|
||||
e2m1 = enum_auto()
|
||||
ue8m0 = enum_auto()
|
||||
ue4m3 = enum_auto()
|
||||
f16 = enum_auto()
|
||||
bf16 = enum_auto()
|
||||
f32 = enum_auto()
|
||||
tf32 = enum_auto()
|
||||
f64 = enum_auto()
|
||||
cf16 = enum_auto()
|
||||
cbf16 = enum_auto()
|
||||
cf32 = enum_auto()
|
||||
ctf32 = enum_auto()
|
||||
cf64 = enum_auto()
|
||||
cs2 = enum_auto()
|
||||
cs4 = enum_auto()
|
||||
cs8 = enum_auto()
|
||||
cs16 = enum_auto()
|
||||
cs32 = enum_auto()
|
||||
cs64 = enum_auto()
|
||||
cu2 = enum_auto()
|
||||
cu4 = enum_auto()
|
||||
cu8 = enum_auto()
|
||||
cu16 = enum_auto()
|
||||
cu32 = enum_auto()
|
||||
cu64 = enum_auto()
|
||||
invalid = enum_auto()
|
||||
|
||||
|
||||
DataTypeSize = {
|
||||
DataType.void: 0,
|
||||
DataType.b1: 1,
|
||||
DataType.u2: 2,
|
||||
DataType.u4: 4,
|
||||
DataType.u8: 8,
|
||||
DataType.u16: 16,
|
||||
DataType.u32: 32,
|
||||
DataType.u64: 64,
|
||||
DataType.s2: 2,
|
||||
DataType.s4: 4,
|
||||
DataType.s8: 8,
|
||||
DataType.s16: 16,
|
||||
DataType.s32: 32,
|
||||
DataType.s64: 64,
|
||||
DataType.e4m3: 8,
|
||||
DataType.e5m2: 8,
|
||||
DataType.f8: 8,
|
||||
DataType.f6: 6,
|
||||
DataType.f4: 4,
|
||||
DataType.e2m3: 6,
|
||||
DataType.e3m2: 6,
|
||||
DataType.e2m1: 4,
|
||||
DataType.ue8m0: 8,
|
||||
DataType.ue4m3: 8,
|
||||
DataType.f16: 16,
|
||||
DataType.bf16: 16,
|
||||
DataType.f32: 32,
|
||||
DataType.tf32: 32,
|
||||
DataType.f64: 64,
|
||||
DataType.cf16: 32,
|
||||
DataType.cbf16: 32,
|
||||
DataType.cf32: 64,
|
||||
DataType.ctf32: 32,
|
||||
DataType.cf64: 128,
|
||||
DataType.cu2: 4,
|
||||
DataType.cu4: 8,
|
||||
DataType.cu8: 16,
|
||||
DataType.cu16: 32,
|
||||
DataType.cu32: 64,
|
||||
DataType.cu64: 128,
|
||||
DataType.cs2: 4,
|
||||
DataType.cs4: 8,
|
||||
DataType.cs8: 16,
|
||||
DataType.cs16: 32,
|
||||
DataType.cs32: 64,
|
||||
DataType.cs64: 128,
|
||||
}
|
||||
|
||||
|
||||
class LayoutType(enum.Enum):
|
||||
ColumnMajor = enum_auto()
|
||||
RowMajor = enum_auto()
|
||||
ColumnMajorInterleaved2 = enum_auto()
|
||||
RowMajorInterleaved2 = enum_auto()
|
||||
ColumnMajorInterleaved32 = enum_auto()
|
||||
RowMajorInterleaved32 = enum_auto()
|
||||
ColumnMajorInterleaved64 = enum_auto()
|
||||
RowMajorInterleaved64 = enum_auto()
|
||||
TensorNWC = enum_auto()
|
||||
TensorNHWC = enum_auto()
|
||||
TensorNDHWC = enum_auto()
|
||||
TensorNCHW = enum_auto()
|
||||
TensorNGHWC = enum_auto()
|
||||
TensorNC32HW32 = enum_auto()
|
||||
TensorNC64HW64 = enum_auto()
|
||||
TensorC32RSK32 = enum_auto()
|
||||
TensorC64RSK64 = enum_auto()
|
||||
TensorKCS = enum_auto()
|
||||
TensorKCSR = enum_auto()
|
||||
TensorKCSRT = enum_auto()
|
||||
|
||||
|
||||
class EpilogueScheduleType(enum.Enum):
|
||||
ScheduleAuto = enum_auto()
|
||||
EpilogueTransposed = enum_auto()
|
||||
NoSmemWarpSpecialized = enum_auto()
|
||||
PtrArrayNoSmemWarpSpecialized = enum_auto()
|
||||
NoSmemWarpSpecialized1Sm = enum_auto()
|
||||
NoSmemWarpSpecialized2Sm = enum_auto()
|
||||
FastF32NoSmemWarpSpecialized1Sm = enum_auto()
|
||||
FastF32NoSmemWarpSpecialized2Sm = enum_auto()
|
||||
BlockwiseNoSmemWarpSpecialized1Sm = enum_auto()
|
||||
BlockwiseNoSmemWarpSpecialized2Sm = enum_auto()
|
||||
PtrArrayNoSmemWarpSpecialized1Sm = enum_auto()
|
||||
PtrArrayNoSmemWarpSpecialized2Sm = enum_auto()
|
||||
PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto()
|
||||
PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto()
|
||||
PtrArrayBlockwiseNoSmemWarpSpecialized1Sm = enum_auto()
|
||||
PtrArrayBlockwiseNoSmemWarpSpecialized2Sm = enum_auto()
|
||||
TmaWarpSpecialized = enum_auto()
|
||||
TmaWarpSpecializedCooperative = enum_auto()
|
||||
TmaWarpSpecialized1Sm = enum_auto()
|
||||
TmaWarpSpecialized2Sm = enum_auto()
|
||||
PtrArrayTmaWarpSpecialized1Sm = enum_auto()
|
||||
PtrArrayTmaWarpSpecialized2Sm = enum_auto()
|
||||
PtrArrayTmaWarpSpecializedPingpong = enum_auto()
|
||||
PtrArrayTmaWarpSpecializedCooperative = enum_auto()
|
||||
TmaWarpSpecialized1SmNvf4 = enum_auto()
|
||||
TmaWarpSpecialized2SmNvf4 = enum_auto()
|
||||
TmaWarpSpecialized1SmMxf4 = enum_auto()
|
||||
TmaWarpSpecialized2SmMxf4 = enum_auto()
|
||||
TmaWarpSpecialized1SmMxf8f6f4 = enum_auto()
|
||||
TmaWarpSpecialized2SmMxf8f6f4 = enum_auto()
|
||||
SparseTmaWarpSpecializedCooperativeSm120 = enum_auto()
|
||||
|
||||
|
||||
class ActivationOp(enum.Enum):
|
||||
DGelu = enum_auto()
|
||||
Gelu = enum_auto()
|
||||
GeluTaylor = enum_auto()
|
||||
HardSwish = enum_auto()
|
||||
Identity = enum_auto()
|
||||
LeakyReLU = enum_auto()
|
||||
ReLU = enum_auto()
|
||||
Sigmoid = enum_auto()
|
||||
SiLU = enum_auto()
|
||||
Tanh = enum_auto()
|
||||
|
||||
|
||||
ActivationOpTag = {
|
||||
ActivationOp.DGelu: "cutlass::epilogue::thread::dGELU",
|
||||
ActivationOp.Gelu: "cutlass::epilogue::thread::GELU",
|
||||
ActivationOp.GeluTaylor: "cutlass::epilogue::thread::GELU_taylor",
|
||||
ActivationOp.HardSwish: "cutlass::epilogue::thread::HardSwish",
|
||||
ActivationOp.Identity: "cutlass::epilogue::thread::Identity",
|
||||
ActivationOp.LeakyReLU: "cutlass::epilogue::thread::LeakyReLU",
|
||||
ActivationOp.ReLU: "cutlass::epilogue::thread::ReLu",
|
||||
ActivationOp.Sigmoid: "cutlass::epilogue::thread::Sigmoid",
|
||||
ActivationOp.SiLU: "cutlass::epilogue::thread::SiLu",
|
||||
ActivationOp.Tanh: "cutlass::epilogue::thread::Tanh",
|
||||
}
|
||||
|
||||
|
||||
class FloatRoundStyle(enum.Enum):
|
||||
ToNearest = enum_auto()
|
||||
ToNearestSatfinite = enum_auto()
|
||||
Indeterminate = enum_auto()
|
||||
TowardZero = enum_auto()
|
||||
TowardInfinity = enum_auto()
|
||||
TowardNegInfinity = enum_auto()
|
||||
HalfUlpTruncDntz = enum_auto()
|
||||
HalfUlpTruncate = enum_auto()
|
||||
|
||||
|
||||
class FunctionalOp(enum.Enum):
|
||||
AtomicAdd = enum_auto()
|
||||
AtomicMaximum = enum_auto()
|
||||
Divides = enum_auto()
|
||||
Maximum = enum_auto()
|
||||
Minimum = enum_auto()
|
||||
Minus = enum_auto()
|
||||
Multiplies = enum_auto()
|
||||
MultiplyAdd = enum_auto()
|
||||
Plus = enum_auto()
|
||||
Exp = enum_auto()
|
||||
|
||||
|
||||
FunctionalOpTag = {
|
||||
FunctionalOp.AtomicAdd: "cutlass::atomic_add",
|
||||
FunctionalOp.AtomicMaximum: "cutlass::atomic_maximum",
|
||||
FunctionalOp.Divides: "cutlass::divides",
|
||||
FunctionalOp.Maximum: "cutlass::maximum",
|
||||
FunctionalOp.Minimum: "cutlass::minimum",
|
||||
FunctionalOp.Minus: "cutlass::minus",
|
||||
FunctionalOp.Multiplies: "cutlass::multiplies",
|
||||
FunctionalOp.MultiplyAdd: "cutlass::multiply_add",
|
||||
FunctionalOp.Plus: "cutlass::plus",
|
||||
FunctionalOp.Exp: "cutlass::fast_exp_op",
|
||||
}
|
||||
|
||||
|
||||
def op_tag(op) -> str:
|
||||
"""
|
||||
Dispatches `op` to the appropriate *Tag dictionary depending on whether
|
||||
`op` is an ActivationOp or FunctionalOp. This is useful for cases in which
|
||||
either type can be used.
|
||||
|
||||
:param op: operation to emit a tag for
|
||||
:type op: ActivationOp | FunctionalOp
|
||||
|
||||
:return: tag corresponding to op
|
||||
:rtype: str
|
||||
"""
|
||||
if isinstance(op, ActivationOp):
|
||||
return ActivationOpTag[op]
|
||||
elif isinstance(op, FunctionalOp):
|
||||
return FunctionalOpTag[op]
|
||||
else:
|
||||
raise Exception(
|
||||
f"Unexpected op type {op}. Must be one of ActivationOp or FunctionalOp."
|
||||
)
|
||||
|
||||
|
||||
# The current EVT frontend also contains code needed for emitting C++ EVTs.
|
||||
# C++ emission is not currently supported by cutlass_api, but we still
|
||||
# need many of the utilities surrounding it. These utilities make use of dictionaries
|
||||
# from cutlass_library types to strings containing C++ code. We define placeholders
|
||||
# for empty versions of these dictionaries that raise an error if they are used.
|
||||
class _UnimplementedDict:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __getitem__(self, key):
|
||||
raise NotImplementedError(
|
||||
f"Dictionary {self.name} is not implemented. This code path should not have been reafched."
|
||||
)
|
||||
|
||||
|
||||
DataTypeTag = _UnimplementedDict("DataTypeTag")
|
||||
EpilogueScheduleTag = _UnimplementedDict("EpilogueScheduleTag")
|
||||
FloatRoundStyleTag = _UnimplementedDict("FloatRoundStyleTag")
|
||||
KernelScheduleSuffixes = _UnimplementedDict("KernelScheduleSuffixes")
|
||||
OpcodeClassTag = _UnimplementedDict("OpcodeClassTag")
|
||||
|
||||
|
||||
_torch_to_library_dict = None
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
_torch_to_library_dict = {
|
||||
torch.half: DataType.f16,
|
||||
torch.float16: DataType.f16,
|
||||
torch.bfloat16: DataType.bf16,
|
||||
torch.float: DataType.f32,
|
||||
torch.float32: DataType.f32,
|
||||
}
|
||||
|
||||
|
||||
def _tensor_from_numpy(np_tensor) -> tuple[DataType, LayoutType]:
|
||||
dtype = library_type(np_tensor.dtype)
|
||||
if np_tensor.flags.c_contiguous:
|
||||
layout = LayoutType.RowMajor
|
||||
elif np_tensor.flags.f_contiguous:
|
||||
layout = LayoutType.ColumnMajor
|
||||
return (dtype, layout)
|
||||
|
||||
|
||||
def _tensor_from_torch(pt_tensor):
|
||||
dtype = library_type(pt_tensor.dtype)
|
||||
return (dtype, LayoutType.RowMajor)
|
||||
|
||||
|
||||
def torch_library_type(inp) -> DataType:
|
||||
if _torch_to_library_dict is None:
|
||||
return None
|
||||
return _torch_to_library_dict.get(inp, None)
|
||||
|
||||
|
||||
def numpy_library_type(inp) -> DataType:
|
||||
if is_numpy_available():
|
||||
import numpy as np
|
||||
|
||||
if inp == np.float16:
|
||||
return DataType.f16
|
||||
elif inp == np.float32:
|
||||
return DataType.f32
|
||||
elif inp == np.float64:
|
||||
return DataType.f64
|
||||
elif inp == np.int8:
|
||||
return DataType.s8
|
||||
elif inp == np.int32:
|
||||
return DataType.s32
|
||||
return None
|
||||
|
||||
|
||||
def cutlass_library_type(inp: cutlass.Numeric) -> DataType:
|
||||
if inp == cutlass.Float32:
|
||||
return DataType.f32
|
||||
elif inp == cutlass.Float16:
|
||||
return DataType.f16
|
||||
elif inp == cutlass.BFloat16:
|
||||
return DataType.bf16
|
||||
elif inp == cutlass.Int32:
|
||||
return DataType.s32
|
||||
elif inp == cutlass.Int8:
|
||||
return DataType.s8
|
||||
elif inp == cutlass.Uint8:
|
||||
return DataType.u8
|
||||
elif inp == cutlass.Uint16:
|
||||
return DataType.u16
|
||||
elif inp == cutlass.Uint32:
|
||||
return DataType.u32
|
||||
elif inp == cutlass.Uint64:
|
||||
return DataType.u64
|
||||
elif inp == cutlass.Int16:
|
||||
return DataType.s16
|
||||
elif inp == cutlass.Int64:
|
||||
return DataType.s64
|
||||
elif inp == cutlass.Float8E5M2:
|
||||
return DataType.e5m2
|
||||
elif inp == cutlass.Float8E4M3FN:
|
||||
return DataType.e4m3
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def library_type(inp):
|
||||
if inp in DataTypeSize:
|
||||
return inp
|
||||
|
||||
for cvt_fn in [
|
||||
numpy_library_type,
|
||||
torch_library_type,
|
||||
cutlass_library_type,
|
||||
]:
|
||||
out = cvt_fn(inp)
|
||||
if out is not None:
|
||||
return out
|
||||
|
||||
raise ValueError(f"No available conversion from type {inp} to a library type.")
|
||||
|
||||
|
||||
def get_datatype_and_layout(tensor) -> tuple[DataType, LayoutType]:
|
||||
if is_numpy_tensor(tensor):
|
||||
return _tensor_from_numpy(tensor)
|
||||
elif is_torch_tensor(tensor):
|
||||
return _tensor_from_torch(tensor)
|
||||
elif isinstance(tensor, float) or isinstance(tensor, int):
|
||||
return (DataType.f32, LayoutType.RowMajor)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout."
|
||||
)
|
||||
|
||||
|
||||
def get_tensor_shape(tensor, op="GEMM") -> tuple:
|
||||
if is_numpy_tensor(tensor):
|
||||
return tensor.shape
|
||||
elif is_torch_tensor(tensor):
|
||||
size = tensor.size()
|
||||
if op == "CONV":
|
||||
# PyTorch Tensors have shape NCHW
|
||||
return (size[0], size[2], size[3], size[1])
|
||||
else:
|
||||
return tuple(tensor.size())
|
||||
elif isinstance(tensor, float) or isinstance(tensor, int):
|
||||
return (1,)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout."
|
||||
)
|
||||
59
python/cutlass_api/cutlass_api/fusion/passes/__init__.py
Normal file
59
python/cutlass_api/cutlass_api/fusion/passes/__init__.py
Normal file
@@ -0,0 +1,59 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass_api.fusion.passes.graph_drawer import EVTGraphDrawer
|
||||
from cutlass_api.fusion.passes.pass_argument_type import PassGetArgumentType
|
||||
from cutlass_api.fusion.passes.pass_dag_2_tree import PassDAG2Tree
|
||||
from cutlass_api.fusion.passes.pass_get_impl import PassGetImpl
|
||||
from cutlass_api.fusion.passes.pass_fix_element_d import PassFixElementD
|
||||
from cutlass_api.fusion.passes.pass_layout_elimination import (
|
||||
PassLayoutManipulateElimination,
|
||||
)
|
||||
from cutlass_api.fusion.passes.pass_manager import EVTPassManager
|
||||
from cutlass_api.fusion.passes.pass_preprocess_red import PassPreprocessRed
|
||||
from cutlass_api.fusion.passes.pass_shape_type_propagation import (
|
||||
PassShapeTypePropagation,
|
||||
)
|
||||
from cutlass_api.fusion.passes.smem_size_calculator import GetSmemSize
|
||||
|
||||
__all__ = [
|
||||
"EVTGraphDrawer",
|
||||
"PassGetArgumentType",
|
||||
"PassDAG2Tree",
|
||||
"PassGetImpl",
|
||||
"PassFixElementD",
|
||||
"PassLayoutManipulateElimination",
|
||||
"EVTPassManager",
|
||||
"PassPreprocessRed",
|
||||
"PassShapeTypePropagation",
|
||||
"GetSmemSize",
|
||||
]
|
||||
133
python/cutlass_api/cutlass_api/fusion/passes/graph_drawer.py
Normal file
133
python/cutlass_api/cutlass_api/fusion/passes/graph_drawer.py
Normal file
@@ -0,0 +1,133 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
from cutlass_api.fusion.library import DataTypeTag
|
||||
from cutlass_api.fusion.ir.dag_ir import DAGIR
|
||||
|
||||
|
||||
_COLOR_MAP = {
|
||||
"load": '"AliceBlue"',
|
||||
"compute": "LemonChiffon1",
|
||||
"accumulator": "LightGrey",
|
||||
"store": "PowderBlue",
|
||||
"layout": "lightseagreen",
|
||||
"dag": "darkorange",
|
||||
}
|
||||
|
||||
|
||||
class EVTGraphDrawer:
|
||||
"""
|
||||
Visualize a EVT DAGIR with graphviz
|
||||
"""
|
||||
|
||||
def __init__(self, graph: DAGIR, name: str):
|
||||
self._name = name
|
||||
self._dot_graphs = {}
|
||||
|
||||
self._dot_graphs[name] = self._to_dot(graph, name)
|
||||
|
||||
def _get_node_style(self, node):
|
||||
template = {
|
||||
"shape": "record",
|
||||
"fillcolor": "#CAFFE3",
|
||||
"style": '"filled,rounded"',
|
||||
"fontcolor": "#000000",
|
||||
}
|
||||
if node.op in _COLOR_MAP:
|
||||
template["fillcolor"] = _COLOR_MAP[node.op]
|
||||
else:
|
||||
raise NotImplementedError("unknown node op")
|
||||
if node.disabled:
|
||||
template["fontcolor"] = "grey"
|
||||
template["fillcolor"] = "white"
|
||||
return template
|
||||
|
||||
def _get_node_label(self, node):
|
||||
label = "{" + f"name={node.name}|op={node.op}"
|
||||
if node.op == "layout":
|
||||
label += f"|fn={node.fn.__name__}"
|
||||
for key in node.kwargs:
|
||||
label += f"|{key}={node.kwargs[key]}"
|
||||
if node.underlying_impl is not None:
|
||||
label += f"|impl={type(node.underlying_impl).__name__}"
|
||||
if node.op == "load":
|
||||
label += f"|element_output={DataTypeTag[node.underlying_impl.element]}"
|
||||
elif node.op == "compute":
|
||||
label += f"|element_compute={DataTypeTag[node.underlying_impl.element_compute]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
|
||||
elif node.op == "store":
|
||||
label += f"|element_store={DataTypeTag[node.underlying_impl.element]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
|
||||
label += f"|element_output={DataTypeTag[node.underlying_impl.element_output]}"
|
||||
if node.tensor is not None:
|
||||
shape = node.tensor.shape
|
||||
stride = node.tensor.stride
|
||||
label += f"|shape={shape}|stride={stride}"
|
||||
|
||||
if hasattr(node, "store_tensor") and node.store_tensor is not None:
|
||||
store_shape = node.store_tensor.shape
|
||||
store_stride = node.store_tensor.stride
|
||||
label += f"|store_shape={store_shape}|store_stride={store_stride}"
|
||||
|
||||
label += "}"
|
||||
return label
|
||||
|
||||
def _to_dot(self, graph: DAGIR, name: str):
|
||||
import pydot
|
||||
|
||||
dot_graph = pydot.Dot(name, rankdir="TB")
|
||||
for node in graph.nodes_meta:
|
||||
style = self._get_node_style(node)
|
||||
label = self._get_node_label(node)
|
||||
dot_node = pydot.Node(node.name, label=label, **style)
|
||||
dot_graph.add_node(dot_node)
|
||||
if node.op == "dag":
|
||||
dot_subgraph = self._to_dot(node.subgraph, name=node.name)
|
||||
self._dot_graphs[node.name] = dot_subgraph
|
||||
|
||||
# Add edges
|
||||
for src, dst in graph.edges:
|
||||
weight = graph.get_edge_weight(src, dst)
|
||||
dot_graph.add_edge(pydot.Edge(src, dst, label=weight))
|
||||
|
||||
return dot_graph
|
||||
|
||||
def get_dot_graph(self) -> "pydot.Dot":
|
||||
return [
|
||||
(key, self.get_dot_graph_by_name(key)) for key in self._dot_graphs.keys()
|
||||
]
|
||||
|
||||
def get_dot_graph_by_name(self, name) -> "pydot.Dot":
|
||||
return self._dot_graphs[name]
|
||||
|
||||
def get_main_dot_graph(self) -> "pydot.Dot":
|
||||
return self._dot_graphs[self._name]
|
||||
@@ -0,0 +1,136 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Construct the epilogue visitor argument type
|
||||
"""
|
||||
|
||||
from cutlass_api.fusion.ir import TopoVisitorNode
|
||||
from cutlass_api.fusion.ir.c_types import visitor_factory
|
||||
from cutlass_api.fusion.passes.pass_dag_2_tree import PassDAG2Tree
|
||||
from cutlass_api.fusion.passes.pass_get_impl import PassGetImpl
|
||||
from cutlass_api.fusion.passes.pass_manager import EVTPassBase
|
||||
from cutlass_api.fusion.passes.pass_shape_type_propagation import (
|
||||
PassShapeTypePropagation,
|
||||
)
|
||||
from cutlass_api.fusion.passes.util import cc_map
|
||||
|
||||
|
||||
class PassGetArgumentType(EVTPassBase):
|
||||
"""
|
||||
Construct the epilogue visitor argument type
|
||||
"""
|
||||
|
||||
dependencies = [
|
||||
PassShapeTypePropagation, # The Layout of all nodes must be set
|
||||
PassDAG2Tree, # The type of each node must be set
|
||||
PassGetImpl, # The DAG subgraphs must be set
|
||||
]
|
||||
|
||||
def requires(self) -> None:
|
||||
# Check "D" is in the node list
|
||||
if cc_map[self.cc] in [90, 100] and (not self.dag_ir.has_node("D")):
|
||||
raise SyntaxError(
|
||||
"Sm90+ EVT requires the epilogue to have a returned tensor D, "
|
||||
"but the variable 'D' is not found in the return values."
|
||||
)
|
||||
|
||||
def call(self):
|
||||
nodes = self.dag_ir.nodes_topological_order()
|
||||
self.argument_types = {}
|
||||
for node in nodes:
|
||||
meta = self.dag_ir.get_node_meta(node)
|
||||
if not meta.disabled:
|
||||
self.argument_types[node] = meta.underlying_impl.argument_type
|
||||
if node == "D" and cc_map[self.cc] in [90, 100]:
|
||||
continue
|
||||
if isinstance(meta, TopoVisitorNode):
|
||||
self.get_dag_argument_type(node)
|
||||
else:
|
||||
self.get_evt_argument_type(node)
|
||||
|
||||
self.cc_specific_method(self.set_argument_type)()
|
||||
|
||||
def get_evt_argument_type(self, node):
|
||||
# Sort the input nodes by edge weight
|
||||
input_types = [
|
||||
self.argument_types[child] for child in self.dag_ir.get_all_inputs(node)
|
||||
]
|
||||
if len(input_types) > 0:
|
||||
self.argument_types[node] = visitor_factory(
|
||||
input_types + [self.argument_types[node]],
|
||||
self.dag_ir.get_all_inputs(node) + [node],
|
||||
)
|
||||
|
||||
def get_dag_argument_type(self, node):
|
||||
meta = self.dag_ir.get_node_meta(node)
|
||||
subgraph = meta.subgraph
|
||||
subgraph_nodes = subgraph.nodes_topological_order()
|
||||
# Visit the unvisited nodes in subgraph
|
||||
for n in subgraph_nodes:
|
||||
M = subgraph.get_node_meta(n)
|
||||
if M.disabled:
|
||||
continue
|
||||
else:
|
||||
self.argument_types[n] = M.underlying_impl.argument_type
|
||||
input_types = [self.argument_types[child] for child in subgraph_nodes[:-1]]
|
||||
if len(input_types) > 0:
|
||||
self.argument_types[node] = visitor_factory(
|
||||
input_types, subgraph_nodes[:-1]
|
||||
)
|
||||
|
||||
def set_argument_type(self):
|
||||
pass
|
||||
|
||||
def sm90_set_argument_type(self):
|
||||
self.dag_ir.epilogue_thread_type = self.argument_types[
|
||||
self.dag_ir.get_all_inputs("D")[0]
|
||||
]
|
||||
# Get the tensorD argument type
|
||||
self.dag_ir.arg_d_type = self.dag_ir.get_node_meta(
|
||||
"D"
|
||||
).underlying_impl.argument_type_d
|
||||
|
||||
# Get the tensorC argument type
|
||||
if self.dag_ir.has_node("C"):
|
||||
self.dag_ir.arg_c_type = self.dag_ir.get_node_meta(
|
||||
"C"
|
||||
).underlying_impl.argument_type_c
|
||||
else:
|
||||
self.dag_ir.arg_c_type = self.dag_ir.arg_d_type
|
||||
|
||||
def sm100_set_argument_type(self):
|
||||
self.sm90_set_argument_type()
|
||||
|
||||
def sm80_set_argument_type(self):
|
||||
nodes = self.dag_ir.nodes_topological_order()
|
||||
self.dag_ir.epilogue_thread_type = self.argument_types[nodes[-1]]
|
||||
176
python/cutlass_api/cutlass_api/fusion/passes/pass_dag_2_tree.py
Normal file
176
python/cutlass_api/cutlass_api/fusion/passes/pass_dag_2_tree.py
Normal file
@@ -0,0 +1,176 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Merge non-tree sub-graphs of the DAG IR into a single DAG. The fused DAG will be implemented
|
||||
by the topological visitor, while the rest of the graph will be implemented with the tree visitor.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from cutlass_api.fusion.ir import DAGIR, TopoVisitorNode
|
||||
from cutlass_api.fusion.passes.pass_get_impl import PassGetImpl
|
||||
from cutlass_api.fusion.passes.pass_manager import EVTPassBase
|
||||
from cutlass_api.fusion.passes.pass_shape_type_propagation import (
|
||||
PassShapeTypePropagation,
|
||||
)
|
||||
|
||||
|
||||
class PassDAG2Tree(EVTPassBase):
|
||||
"""
|
||||
Convert the DAG IR to Tree by fusing subgraphs
|
||||
"""
|
||||
|
||||
dependencies = [PassShapeTypePropagation, PassGetImpl]
|
||||
|
||||
def call(self):
|
||||
# Step 1: find the nodes that have multiple parents
|
||||
multi_parent_nodes = []
|
||||
|
||||
for node in self.dag_ir.nodes_topological_order():
|
||||
if self.dag_ir.out_degree(node) > 1:
|
||||
multi_parent_nodes.append(node)
|
||||
# Step 2: find the lowest common ancestor (LCA) of all its parents
|
||||
for node in multi_parent_nodes:
|
||||
# A multi-parent node could be already fused by the previous node
|
||||
if not self.dag_ir.has_node(node):
|
||||
continue
|
||||
# A node uncovered by the previous fusions can have out degree change
|
||||
# Case 1: it has <= 1 edges to the previously fused subgraph, no degree change
|
||||
# Case 2: it has more than one edges to the previously fused subgraph, degree drops
|
||||
if self.dag_ir.out_degree(node) <= 1:
|
||||
continue
|
||||
|
||||
# Otherwise, the node still
|
||||
reachable_nodes = []
|
||||
# Complexity: O(Dout*N)
|
||||
for parent in self.dag_ir.get_users(node):
|
||||
reachable_nodes.append(set(self.dag_ir.all_reachable_nodes(parent)))
|
||||
# get the common reachable objects
|
||||
common_items = set.intersection(*reachable_nodes)
|
||||
node_to_fuse = set.union(*reachable_nodes).difference(common_items)
|
||||
|
||||
lca = None
|
||||
# If common ancestor exists, find the lowest one
|
||||
if len(common_items) > 0:
|
||||
topo_order = self.dag_ir.nodes_topological_order()
|
||||
topo_idx = -1
|
||||
for item in common_items:
|
||||
if lca is None:
|
||||
lca = item
|
||||
topo_idx = topo_order.index(item)
|
||||
else:
|
||||
if topo_idx > topo_order.index(item):
|
||||
lca = item
|
||||
topo_idx = topo_order.index(item)
|
||||
else:
|
||||
# there is no common ancestor for all the parents, we pack all the reachable
|
||||
# nodes into a single DAG node as a fallback. The lca should be the input node of
|
||||
# one of the output nodes with out_degree = 0
|
||||
potential_output_nodes = []
|
||||
for node in node_to_fuse:
|
||||
if self.dag_ir.out_degree(node) == 0:
|
||||
potential_output_nodes.append(node)
|
||||
if len(potential_output_nodes) == 0:
|
||||
raise RuntimeError("No output node with out degree = 0 found.")
|
||||
|
||||
output_node = None
|
||||
if self.dag_ir.cc >= 90:
|
||||
# For SM90+, the lca should be the input node of D
|
||||
if not self.dag_ir.has_node("D"):
|
||||
raise RuntimeError("D is not a node in the DAG IR.")
|
||||
output_node = "D"
|
||||
else:
|
||||
output_node = potential_output_nodes[0]
|
||||
|
||||
if output_node is None:
|
||||
raise RuntimeError("No output node found.")
|
||||
lca = self.dag_ir.get_all_inputs(output_node)[0]
|
||||
node_to_fuse.remove(output_node)
|
||||
|
||||
# The lca is the output node of the DAG node
|
||||
# Get the nodes to be fused
|
||||
node_to_fuse.add(lca)
|
||||
# Get all the input nodes
|
||||
all_input_nodes = []
|
||||
all_output_nodes = []
|
||||
for node in node_to_fuse:
|
||||
all_input_nodes.append(set(self.dag_ir.get_all_inputs(node)))
|
||||
all_output_nodes.append(set(self.dag_ir.get_users(node)))
|
||||
all_input_nodes = set.union(*all_input_nodes)
|
||||
all_output_nodes = set.union(*all_output_nodes)
|
||||
|
||||
new_subgraph_nodes = set.union(
|
||||
node_to_fuse, all_input_nodes, all_output_nodes
|
||||
)
|
||||
|
||||
# Create the subgraph
|
||||
subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
|
||||
subgraph = DAGIR(self.dag_ir.cc)
|
||||
for node in subgraph_.nodes:
|
||||
meta = deepcopy(self.dag_ir.get_node_meta(node))
|
||||
if node not in node_to_fuse:
|
||||
meta.disabled = True
|
||||
subgraph.add_node(meta)
|
||||
for edge in subgraph_.edges:
|
||||
subgraph.add_edge(
|
||||
edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1])
|
||||
)
|
||||
|
||||
# Create the fused node
|
||||
dag_node = TopoVisitorNode(
|
||||
name=f"dag_{lca}",
|
||||
subgraph=subgraph,
|
||||
output_node=self.dag_ir.get_node_meta(lca),
|
||||
)
|
||||
self.dag_ir.add_node(dag_node)
|
||||
|
||||
# Add input edges
|
||||
for idx, node in enumerate(all_input_nodes):
|
||||
self.dag_ir.add_edge(node, dag_node.name, weight=idx)
|
||||
|
||||
# Replace all uses with DAG node (only 1 output node)
|
||||
self.dag_ir.replace_all_uses_with(lca, dag_node.name)
|
||||
|
||||
# Remove all fused nodes
|
||||
node_to_fuse.remove(lca)
|
||||
for node in node_to_fuse:
|
||||
self.dag_ir.remove_node(node)
|
||||
|
||||
def ensures(self) -> None:
|
||||
# Ensure that after the pass, the resulting DAG becomes a tree
|
||||
for node in self.dag_ir.nodes:
|
||||
out_degree = self.dag_ir.out_degree(node)
|
||||
if out_degree > 1:
|
||||
raise RuntimeError(
|
||||
f"PassDAG2Tree failed. Node {node} still have outdegree = {out_degree}"
|
||||
)
|
||||
@@ -0,0 +1,86 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Fix the element_output of producer of D.
|
||||
|
||||
In Sm90 epilogue visitor, the node writing D to gmem does not have internal
|
||||
element converter, so the compute node producing D must have element_output = type(D).
|
||||
"""
|
||||
|
||||
from cutlass_api.fusion.passes.pass_layout_elimination import (
|
||||
PassLayoutManipulateElimination,
|
||||
)
|
||||
from cutlass_api.fusion.passes.pass_manager import EVTPassBase
|
||||
|
||||
|
||||
class PassFixElementD(EVTPassBase):
|
||||
"""
|
||||
In Sm90 epilogue visitor, the node writing D to gmem does not have internal
|
||||
element converter, so the compute node producing D must have
|
||||
element_output = type(D)
|
||||
"""
|
||||
|
||||
dependencies = [PassLayoutManipulateElimination]
|
||||
|
||||
def get_producer(self, node: str, element_D, visited=None):
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
if node in visited:
|
||||
raise RuntimeError(
|
||||
f"Cycle detected while traversing to producer of D: {node}"
|
||||
)
|
||||
visited.add(node)
|
||||
|
||||
node_meta = self.dag_ir.get_node_meta(node)
|
||||
if node_meta.op == "compute":
|
||||
node_meta.element_output = element_D
|
||||
elif node_meta.op == "store":
|
||||
inputs = self.dag_ir.get_all_inputs(node)
|
||||
if len(inputs) != 1:
|
||||
raise RuntimeError(
|
||||
f"Store node {node} has {len(inputs)} inputs, expected 1"
|
||||
)
|
||||
self.get_producer(inputs[0], element_D, visited)
|
||||
elif node_meta.op == "load":
|
||||
node_meta.element_output = element_D
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported node op: {node_meta.op} when getting producer of D"
|
||||
)
|
||||
|
||||
def call(self):
|
||||
if self.dag_ir.has_node("D"):
|
||||
node_d_meta = self.dag_ir.get_node_meta("D")
|
||||
element_D = node_d_meta.store_tensor.element
|
||||
self.get_producer("D", element_D)
|
||||
@@ -0,0 +1,93 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Infer the underlying implement of each node.
|
||||
|
||||
While the frontend only distinguish between Load/Store/Compute Node,
|
||||
each of these nodes can have different underlying implementation based
|
||||
on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
|
||||
This pass infers the underlying impl of each node
|
||||
"""
|
||||
|
||||
import cutlass_api.fusion.backend as evt_backend
|
||||
from cutlass_api.fusion.ir import DAGIR, LoadNode
|
||||
from cutlass_api.fusion.passes.pass_fix_element_d import PassFixElementD
|
||||
from cutlass_api.fusion.passes.pass_manager import EVTPassBase
|
||||
from cutlass_api.fusion.passes.pass_no_op_elimination import PassNoOpElimination
|
||||
from cutlass_api.fusion.passes.pass_shape_type_propagation import (
|
||||
PassShapeTypePropagation,
|
||||
)
|
||||
from cutlass_api.fusion.passes.util import cc_map
|
||||
|
||||
|
||||
class PassGetImpl(EVTPassBase):
|
||||
"""
|
||||
While the frontend only distinguish between Load/Store/Compute Node,
|
||||
each of these nodes can have different underlying implementation based
|
||||
on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
|
||||
This pass infers the underlying impl of each node
|
||||
"""
|
||||
|
||||
dependencies = [
|
||||
PassShapeTypePropagation, # The shape and type info are required for inference
|
||||
PassFixElementD,
|
||||
]
|
||||
|
||||
def __init__(self, dag_ir: DAGIR) -> None:
|
||||
super().__init__(dag_ir)
|
||||
self.no_op_elimination = PassNoOpElimination(dag_ir)
|
||||
|
||||
def requires(self) -> None:
|
||||
# Verify "accum" is in the arg list
|
||||
if not self.dag_ir.has_node("accum"):
|
||||
raise SyntaxError("Cannot find 'accum' in the argument list.")
|
||||
|
||||
def call(self):
|
||||
# The loop structure of the epilogue is determined by the
|
||||
# accumulator shape
|
||||
accumulator: LoadNode = self.dag_ir.get_node_meta("accum")
|
||||
problem_size = accumulator.tensor.shape
|
||||
|
||||
for node_meta in self.dag_ir.node_metas_topological_order():
|
||||
node_meta.get_underlying_impl(problem_size)
|
||||
|
||||
def ensures(self) -> None:
|
||||
# Some nodes will be lowered to NoOp, eliminate them
|
||||
self.no_op_elimination()
|
||||
# Lower to cc-specific impl
|
||||
for node_meta in self.dag_ir.nodes_meta:
|
||||
node_impl_ccs = getattr(evt_backend, f"sm{cc_map[self.cc]}_nodes")
|
||||
node_meta.underlying_impl = getattr(
|
||||
node_impl_ccs,
|
||||
f"Sm{cc_map[self.cc]}" + node_meta.underlying_impl.__class__.__name__,
|
||||
)(node_meta)
|
||||
@@ -0,0 +1,230 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Eliminate layout manipulation nodes
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from cutlass_api.fusion.ir import DAGIR, LayoutNode
|
||||
from cutlass_api.fusion.passes.pass_manager import EVTPassBase
|
||||
from cutlass_api.fusion.passes.pass_shape_type_propagation import (
|
||||
PassShapeTypePropagation,
|
||||
)
|
||||
|
||||
|
||||
class PassLayoutManipulateElimination(EVTPassBase):
|
||||
"""
|
||||
Eliminate layout manipulation nodes
|
||||
"""
|
||||
|
||||
dependencies = [PassShapeTypePropagation]
|
||||
|
||||
def __init__(self, dag_ir: DAGIR) -> None:
|
||||
super().__init__(dag_ir)
|
||||
self.copy_cnt = 0
|
||||
|
||||
def call(self):
|
||||
self.layout_nodes_worklist = self.get_all_layout_nodes()
|
||||
# Run while loop utill all layout nodes are eliminated
|
||||
while len(self.layout_nodes_worklist) > 0:
|
||||
node = self.layout_nodes_worklist.pop(0)
|
||||
# for node in layout_nodes:
|
||||
# Step 1: get the propagation direction
|
||||
direction = self.get_propagation_direction(node)
|
||||
self.visited = []
|
||||
getattr(self, f"propagate_to_{direction}")(
|
||||
self.dag_ir.get_node_meta(node), node
|
||||
)
|
||||
# Eliminate the current node
|
||||
input_node = self.dag_ir.get_all_inputs(node)[0]
|
||||
self.dag_ir.replace_all_uses_with(node, input_node)
|
||||
# layout_nodes = self.get_all_layout_nodes()
|
||||
|
||||
def get_all_layout_nodes(self):
|
||||
layout_nodes = []
|
||||
for node_meta in reversed(self.dag_ir.node_metas_topological_order()):
|
||||
if isinstance(node_meta, LayoutNode):
|
||||
layout_nodes.append(node_meta.name)
|
||||
return layout_nodes
|
||||
|
||||
def get_propagation_direction(self, node: str):
|
||||
"""
|
||||
The logic is propagating all layout nodes away from the accumulator node.
|
||||
"""
|
||||
self.visited = []
|
||||
self.get_influenced_users(node)
|
||||
nodes_influenced_dir_users = self.visited
|
||||
self.visited = []
|
||||
self.get_influenced_inputs(node)
|
||||
nodes_influenced_dir_inputs = self.visited
|
||||
|
||||
if (
|
||||
"accum" in nodes_influenced_dir_users
|
||||
and "accum" not in nodes_influenced_dir_inputs
|
||||
):
|
||||
return "inputs"
|
||||
elif (
|
||||
"accum" not in nodes_influenced_dir_users
|
||||
and "accum" in nodes_influenced_dir_inputs
|
||||
):
|
||||
return "users"
|
||||
else:
|
||||
raise RuntimeError("Unsolved propagation direction")
|
||||
|
||||
# Get all influenced nodes if we propagate along the user direction
|
||||
def get_influenced_users(self, node: str):
|
||||
if node in self.visited:
|
||||
return
|
||||
self.visited.append(node)
|
||||
|
||||
users = self.dag_ir.get_users(node)
|
||||
for user in users:
|
||||
self.get_influenced_users(user)
|
||||
user_inputs = []
|
||||
for user in users:
|
||||
user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
|
||||
if len(user_inputs) > 0:
|
||||
user_inputs = set.union(*user_inputs)
|
||||
user_inputs.remove(node)
|
||||
for input_node in user_inputs:
|
||||
self.get_influenced_inputs(input_node)
|
||||
|
||||
# Get all influenced nodes if we propagate along the input direction
|
||||
def get_influenced_inputs(self, node: str):
|
||||
if node in self.visited:
|
||||
return
|
||||
self.visited.append(node)
|
||||
|
||||
inputs = self.dag_ir.get_all_inputs(node)
|
||||
for input_node in inputs:
|
||||
self.get_influenced_inputs(input_node)
|
||||
input_users = []
|
||||
for input_node in inputs:
|
||||
input_users.append(set(self.dag_ir.get_users(input_node)))
|
||||
if len(input_users) > 0:
|
||||
input_users = set.union(*input_users)
|
||||
input_users.remove(node)
|
||||
for user in input_users:
|
||||
self.get_influenced_users(user)
|
||||
|
||||
def add_copy_before(self, layout_node_meta: LayoutNode, target: str):
|
||||
copied_node_meta = deepcopy(layout_node_meta)
|
||||
copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
|
||||
self.copy_cnt += 1
|
||||
copied_node_meta.name = copied_node
|
||||
self.dag_ir.add_node(copied_node_meta)
|
||||
# Add edges
|
||||
target_inputs = self.dag_ir.get_all_inputs(target)
|
||||
for src in target_inputs:
|
||||
self.dag_ir.remove_edge(src, target)
|
||||
self.dag_ir.add_edge(src, copied_node)
|
||||
self.dag_ir.add_edge(copied_node, target)
|
||||
self.layout_nodes_worklist.append(copied_node)
|
||||
|
||||
def add_copy_after(self, layout_node_meta: LayoutNode, target: str):
|
||||
copied_node_meta = deepcopy(layout_node_meta)
|
||||
copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
|
||||
self.copy_cnt += 1
|
||||
copied_node_meta.name = copied_node
|
||||
self.dag_ir.add_node(copied_node_meta)
|
||||
# Add edges
|
||||
users = self.dag_ir.get_users(target)
|
||||
for user in users:
|
||||
self.dag_ir.remove_edge(target, user)
|
||||
self.dag_ir.add_edge(copied_node, user)
|
||||
self.dag_ir.add_edge(target, copied_node)
|
||||
self.layout_nodes_worklist.append(copied_node)
|
||||
|
||||
# Propagate the layout `node` along the user direction
|
||||
def propagate_to_users(self, layout_node_meta: LayoutNode, node: str):
|
||||
"""
|
||||
Propagate layout node to users
|
||||
"""
|
||||
if node in self.visited:
|
||||
# Avoid applying twice
|
||||
return
|
||||
self.visited.append(node)
|
||||
|
||||
node_meta = self.dag_ir.get_node_meta(node)
|
||||
if layout_node_meta.name != node:
|
||||
if isinstance(node_meta, LayoutNode):
|
||||
# Layout node is not transparent with layout node
|
||||
self.add_copy_before(layout_node_meta, node)
|
||||
return
|
||||
else:
|
||||
layout_node_meta.apply_to_user(node_meta)
|
||||
|
||||
users = self.dag_ir.get_users(node)
|
||||
user_inputs = []
|
||||
for user in users:
|
||||
user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
|
||||
for user in users:
|
||||
self.propagate_to_users(layout_node_meta, user)
|
||||
if len(user_inputs) > 0:
|
||||
user_inputs = set.union(*user_inputs)
|
||||
user_inputs.remove(node)
|
||||
for input_node in user_inputs:
|
||||
self.propagate_to_inputs(
|
||||
layout_node_meta.get_inverse_node(), input_node
|
||||
)
|
||||
|
||||
# Propagate the layout `node` along the input direction
|
||||
def propagate_to_inputs(self, layout_node_meta: LayoutNode, node: str):
|
||||
"""
|
||||
Propagate layout node to inputs
|
||||
"""
|
||||
if node in self.visited:
|
||||
# Avoid applying twice
|
||||
return
|
||||
self.visited.append(node)
|
||||
|
||||
node_meta = self.dag_ir.get_node_meta(node)
|
||||
if layout_node_meta.name != node:
|
||||
if isinstance(node_meta, LayoutNode):
|
||||
# Layout node is not transparent with layout node
|
||||
self.add_copy_after(layout_node_meta, node)
|
||||
return
|
||||
else:
|
||||
layout_node_meta.apply_to_input(node_meta)
|
||||
inputs = self.dag_ir.get_all_inputs(node)
|
||||
input_users = []
|
||||
for input_node in inputs:
|
||||
input_users.append(set(self.dag_ir.get_users(input_node)))
|
||||
for input_node in inputs:
|
||||
self.propagate_to_inputs(layout_node_meta, input_node)
|
||||
if len(input_users) > 0:
|
||||
input_users = set.union(*input_users)
|
||||
input_users.remove(node)
|
||||
for user in input_users:
|
||||
self.propagate_to_users(layout_node_meta.get_inverse_node(), user)
|
||||
185
python/cutlass_api/cutlass_api/fusion/passes/pass_manager.py
Normal file
185
python/cutlass_api/cutlass_api/fusion/passes/pass_manager.py
Normal file
@@ -0,0 +1,185 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Pass manager for DAG IR.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from cutlass_api.fusion.ir import DAGIR
|
||||
from cutlass_api.fusion.passes.util import cc_map
|
||||
|
||||
|
||||
class EVTPassBase:
|
||||
"""
|
||||
Base class for EVT Passes
|
||||
"""
|
||||
|
||||
dependencies = []
|
||||
|
||||
def __init__(self, dag_ir: DAGIR) -> None:
|
||||
self.dag_ir = dag_ir
|
||||
self.cc = self.dag_ir.cc
|
||||
|
||||
def requires(self) -> None:
|
||||
"""
|
||||
This function will be called before the pass is run.
|
||||
"""
|
||||
pass
|
||||
|
||||
def call(self) -> None:
|
||||
"""
|
||||
The pass that is run through the self.dag_ir
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"__call__ is not overwritten in Pass {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
def ensures(self) -> None:
|
||||
"""
|
||||
This function will be called after the pass is run.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __call__(self) -> Any:
|
||||
self.requires()
|
||||
self.call()
|
||||
self.ensures()
|
||||
|
||||
def cc_specific_method(self, func):
|
||||
"""
|
||||
This enables defining function that behaves differently under different cc
|
||||
The simplest example of using this function is the following
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
class ExamplePass(EVTPassBase):
|
||||
|
||||
def call(sekf):
|
||||
# This automatically select the smXX_func based on current cc
|
||||
self.cc_specific_method(self.func)()
|
||||
|
||||
# Interface func, can be empty
|
||||
def func(self):
|
||||
pass
|
||||
|
||||
# Sm90 specific func
|
||||
def sm90_func(self):
|
||||
// sm90 specific method
|
||||
return
|
||||
|
||||
# Sm80 specific func
|
||||
def sm80_func(self):
|
||||
// sm80 specific method
|
||||
return
|
||||
"""
|
||||
func_name = f"sm{cc_map[self.cc]}_{func.__name__}"
|
||||
if hasattr(self, func_name):
|
||||
return getattr(self, func_name)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"func {func.__name__} is not overwritten for Sm{self.cc}"
|
||||
)
|
||||
|
||||
|
||||
class EVTPassManager(nx.DiGraph):
|
||||
"""
|
||||
Topological-based Pass Manager.
|
||||
Each registered pass has a list of dependencies. The pass manager organizes
|
||||
the passes as a DAG and launch the compiler passes under topological order.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# NetworkX 3.0+ changed DiGraph.__new__ to accept fewer arguments.
|
||||
# Override to accept and ignore extra arguments from __init__.
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, dag_ir: DAGIR, pass_list, soft_dependencies=None):
|
||||
super().__init__()
|
||||
self.dag_ir = dag_ir
|
||||
for pass_cls in pass_list:
|
||||
self.add_pass(pass_cls)
|
||||
|
||||
self.passes = pass_list
|
||||
self.soft_dependencies = soft_dependencies
|
||||
if self.soft_dependencies is None:
|
||||
self.soft_dependencies = []
|
||||
|
||||
self.sorted_passes = self.schedule()
|
||||
|
||||
def get_callable(self, pass_name):
|
||||
"""
|
||||
Return the callable of the pass
|
||||
"""
|
||||
return self.nodes[pass_name]["callable"]
|
||||
|
||||
def add_pass(self, pass_cls):
|
||||
"""
|
||||
Add a pass to the pass manager
|
||||
:param pass_cls: the class of pass
|
||||
:type pass_cls: derived class of EVTPassBase
|
||||
"""
|
||||
name = pass_cls.__name__
|
||||
pass_callable = pass_cls(self.dag_ir)
|
||||
self.add_node(name, callable=pass_callable)
|
||||
|
||||
def schedule(self):
|
||||
"""
|
||||
Schedule the added passes under topological order
|
||||
"""
|
||||
# Add edges
|
||||
for pass_name in self.nodes:
|
||||
callable = self.get_callable(pass_name)
|
||||
for dependency_cls in callable.dependencies:
|
||||
if dependency_cls not in self.passes:
|
||||
if dependency_cls not in self.soft_dependencies:
|
||||
raise ValueError(
|
||||
f"Pass {pass_name} depends on {dependency_cls.__name__}, which is not in the pass list"
|
||||
)
|
||||
else:
|
||||
continue
|
||||
self.add_edge(dependency_cls.__name__, type(callable).__name__)
|
||||
|
||||
# Topological sort
|
||||
return list(nx.topological_sort(self))
|
||||
|
||||
def __call__(self) -> Any:
|
||||
"""
|
||||
Launch the registered passes
|
||||
"""
|
||||
for pass_name in self.sorted_passes:
|
||||
callable = self.get_callable(pass_name)
|
||||
callable()
|
||||
@@ -0,0 +1,59 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
No op elimination node
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from cutlass_api.fusion.ir import NoOpImpl
|
||||
from cutlass_api.fusion.passes.pass_manager import EVTPassBase
|
||||
|
||||
|
||||
class PassNoOpElimination(EVTPassBase):
|
||||
"""
|
||||
The dead node elimination pass removes nodes with NoOpImpl in DAG IR
|
||||
"""
|
||||
|
||||
dependencies = []
|
||||
|
||||
def call(self) -> Any:
|
||||
for node in self.dag_ir.nodes_topological_order():
|
||||
node_meta = self.dag_ir.get_node_meta(node)
|
||||
if isinstance(node_meta.underlying_impl, NoOpImpl):
|
||||
inputs = self.dag_ir.get_all_inputs(node)
|
||||
if len(inputs) != 1:
|
||||
raise RuntimeError(
|
||||
f"NoOp node {node} has {len(inputs)} inputs, expected 1"
|
||||
)
|
||||
self.dag_ir.replace_all_uses_with(node, inputs[0])
|
||||
@@ -0,0 +1,96 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Preprocess the reduction nodes.
|
||||
|
||||
The parser treats reduction as Compute(op=(reg_reduce_fn, gmem_reduce_fn)) - Store()
|
||||
This pass fuses these into a single store node, and then replaces all uses of the
|
||||
current node with the new store node.
|
||||
"""
|
||||
|
||||
from cutlass_api.fusion.ir import ComputeNode, StoreNode
|
||||
from cutlass_api.fusion.passes.pass_manager import EVTPassBase
|
||||
|
||||
|
||||
class PassPreprocessRed(EVTPassBase):
|
||||
"""
|
||||
Preprocess red nodes
|
||||
"""
|
||||
|
||||
def call(self):
|
||||
# Step 1: find the compute nodes with op=red
|
||||
red_compute_nodes = []
|
||||
for node_meta in self.dag_ir.nodes_meta:
|
||||
if isinstance(node_meta, ComputeNode) and type(node_meta.fn) == tuple:
|
||||
# To keep the frontend simple, the reduction nodes
|
||||
# are parsed into compute nodes by default
|
||||
# The simple heuristic to distinguish between compute
|
||||
# and reduction node is that compute node is a single function,
|
||||
# while the reduction node is a tuple of functions for
|
||||
# in-register reduction and atomic global memory reduction
|
||||
red_compute_nodes.append(node_meta.name)
|
||||
|
||||
# Step 2: for each compute, merge it with the succeeding store
|
||||
for node in red_compute_nodes:
|
||||
# Verify
|
||||
users = self.dag_ir.get_users(node)
|
||||
inputs = self.dag_ir.get_all_inputs(node)
|
||||
# Has a single user
|
||||
assert len(users) == 1
|
||||
assert len(inputs) == 1
|
||||
user = users[0]
|
||||
input_node = inputs[0]
|
||||
|
||||
user_meta = self.dag_ir.get_node_meta(user)
|
||||
# Must be a store node
|
||||
assert isinstance(user_meta, StoreNode)
|
||||
# With output degree == 0
|
||||
assert self.dag_ir.out_degree(user) == 0
|
||||
# Register the reduce op
|
||||
node_meta = self.dag_ir.get_node_meta(node)
|
||||
user_meta.reg_reduce_fn, user_meta.gmem_reduce_fn = node_meta.fn
|
||||
user_meta.element_compute = node_meta.element_compute
|
||||
user_meta.round_style = node_meta.round_style
|
||||
|
||||
# Replace all uses
|
||||
self.dag_ir.remove_edge(input_node, node)
|
||||
input_users = self.dag_ir.get_users(input_node)
|
||||
for iu in input_users:
|
||||
weight = self.dag_ir.get_edge_weight(input_node, iu)
|
||||
self.dag_ir.add_edge(user, iu, weight)
|
||||
self.dag_ir.remove_edge(input_node, iu)
|
||||
self.dag_ir.add_edge(input_node, user)
|
||||
self.dag_ir.remove_node(node)
|
||||
|
||||
# Register the reduction name
|
||||
self.dag_ir.reduction_names.append(user)
|
||||
@@ -0,0 +1,60 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Shape and type propagation pass
|
||||
"""
|
||||
|
||||
from cutlass_api.fusion.ir.node import NodeBase
|
||||
from cutlass_api.fusion.passes.pass_manager import EVTPassBase
|
||||
from cutlass_api.fusion.passes.pass_preprocess_red import PassPreprocessRed
|
||||
|
||||
|
||||
class PassShapeTypePropagation(EVTPassBase):
|
||||
"""
|
||||
Propagate the shape and type of all nodes
|
||||
"""
|
||||
|
||||
dependencies = [PassPreprocessRed]
|
||||
|
||||
def call(self):
|
||||
# Propagate the node shape and type
|
||||
for node in self.dag_ir.nodes_topological_order():
|
||||
node_meta: NodeBase = self.dag_ir.get_node_meta(node)
|
||||
input_node_metas = self.dag_ir.get_all_inputs_meta(node)
|
||||
node_meta.type_propagation(input_node_metas)
|
||||
node_meta.shape_propagation(input_node_metas)
|
||||
|
||||
for node in reversed(self.dag_ir.nodes_topological_order()):
|
||||
node_meta: NodeBase = self.dag_ir.get_node_meta(node)
|
||||
input_node_metas = self.dag_ir.get_all_inputs_meta(node)
|
||||
node_meta.broadcast_propagation(input_node_metas)
|
||||
@@ -0,0 +1,363 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Compute the shared memory size in bytes
|
||||
"""
|
||||
|
||||
from math import gcd
|
||||
|
||||
from cutlass_api.fusion.library import DataType, DataTypeSize, EpilogueScheduleType
|
||||
from cutlass_api.fusion.ir import TopoVisitorNode, DAGIR
|
||||
from cutlass_api.fusion.pycute import flatten, shape_div, product
|
||||
|
||||
|
||||
class GetSmemSize:
|
||||
"""
|
||||
Get the size in byte of shared memory used by the kernel
|
||||
"""
|
||||
|
||||
def __init__(self, dag_ir: DAGIR) -> None:
|
||||
self.dag_ir = dag_ir
|
||||
self.cc = self.dag_ir.cc
|
||||
|
||||
#
|
||||
# Sm90 epilogue specific
|
||||
#
|
||||
|
||||
def sm90_epilogue_tile(self, tile_description):
|
||||
# Get the epilogue tile size
|
||||
schedule = tile_description.epilogue_schedule
|
||||
if schedule == EpilogueScheduleType.TmaWarpSpecialized:
|
||||
element_d = self.dag_ir.get_node_meta("D").element
|
||||
nperf = (
|
||||
64
|
||||
if (
|
||||
DataTypeSize[element_d] == 8
|
||||
and tile_description.threadblock_shape[1] % 64 == 0
|
||||
)
|
||||
else 32
|
||||
)
|
||||
epi_tile_m = min(64, tile_description.threadblock_shape[0])
|
||||
epi_tile_n = gcd(
|
||||
min(nperf, tile_description.threadblock_shape[1]),
|
||||
tile_description.threadblock_shape[1],
|
||||
)
|
||||
epilogue_tile_mn = (epi_tile_m, epi_tile_n)
|
||||
elif schedule == EpilogueScheduleType.TmaWarpSpecializedCooperative:
|
||||
epi_tile_m = min(128, tile_description.threadblock_shape[0])
|
||||
epi_tile_n = gcd(
|
||||
min(32, tile_description.threadblock_shape[1]),
|
||||
tile_description.threadblock_shape[1],
|
||||
)
|
||||
epilogue_tile_mn = (epi_tile_m, epi_tile_n)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported schedule: {schedule}")
|
||||
|
||||
# Get the pipeline stages
|
||||
stages_d = 2
|
||||
epi_tiles = product(
|
||||
shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn)
|
||||
)
|
||||
if self.dag_ir.has_node("C"):
|
||||
element_c = self.dag_ir.get_node_meta("C").element
|
||||
else:
|
||||
element_c = None
|
||||
|
||||
element_d = self.dag_ir.get_node_meta("D").element
|
||||
if element_c == element_d:
|
||||
reuse_smem_c = True
|
||||
else:
|
||||
reuse_smem_c = False
|
||||
stages_c = max(epi_tiles, stages_d + 1) if reuse_smem_c else epi_tiles
|
||||
|
||||
# Record the epilogue tile
|
||||
self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
|
||||
self.epilogue_tile_mn = epilogue_tile_mn
|
||||
self.epi_tiles = epi_tiles
|
||||
self.stages_c = stages_c
|
||||
self.stages_d = stages_d
|
||||
self.reuse_smem_c = reuse_smem_c
|
||||
self.element_c = element_c
|
||||
self.element_d = element_d
|
||||
self.is_source_supported = element_c is not None
|
||||
|
||||
def sm90_or_sm100_epilogue_smem_size(self, tile_description):
|
||||
# Get the Fusion Storage
|
||||
nodes = self.dag_ir.nodes_topological_order()
|
||||
self.smem_types = {}
|
||||
for node in nodes:
|
||||
meta = self.dag_ir.get_node_meta(node)
|
||||
if not meta.disabled:
|
||||
self.smem_types[node] = meta.underlying_impl.get_smem_size(
|
||||
self.cta_tile_mnk,
|
||||
self.epilogue_tile_mn,
|
||||
self.stages_c,
|
||||
self.stages_d,
|
||||
self.epi_tiles,
|
||||
)
|
||||
if node == "D":
|
||||
continue
|
||||
if isinstance(meta, TopoVisitorNode):
|
||||
self.get_dag_smem_type(node)
|
||||
else:
|
||||
self.get_evt_smem_type(node)
|
||||
|
||||
thread_smem_size = self.smem_types[self.dag_ir.get_all_inputs("D")[0]][0]
|
||||
# Get the Tensor Storage
|
||||
tensors = []
|
||||
if self.is_source_supported:
|
||||
smem_C = (
|
||||
DataTypeSize[self.element_c]
|
||||
* product(self.epilogue_tile_mn)
|
||||
* self.stages_c
|
||||
// 8
|
||||
)
|
||||
tensors.append((smem_C, 128))
|
||||
else:
|
||||
tensors.append((0, 1))
|
||||
if self.reuse_smem_c:
|
||||
tensors.append((0, 128))
|
||||
else:
|
||||
smem_D = (
|
||||
DataTypeSize[self.element_d]
|
||||
* product(self.epilogue_tile_mn)
|
||||
* self.stages_d
|
||||
// 8
|
||||
)
|
||||
tensors.append((smem_D, 128))
|
||||
tensors.append((thread_smem_size, 128))
|
||||
|
||||
tensor_smem_size = self.get_struct_size(tensors)
|
||||
# Get pipeline storage size
|
||||
# sizeof(uint64_t * stages_c * 2), alignment of uint64_t
|
||||
# 2 is for FullBarrier and EmptyBarrier
|
||||
pipeline_smem_size = (8 * self.stages_c * 2, 8)
|
||||
|
||||
# get SharedStorage size
|
||||
smem_size = self.get_struct_size([tensor_smem_size, pipeline_smem_size])
|
||||
return smem_size[0]
|
||||
|
||||
def sm90_epilogue_smem_size(self, tile_description):
|
||||
"""
|
||||
Compute the shared memory size of sm90 collective epilogue
|
||||
"""
|
||||
self.sm90_epilogue_tile(tile_description)
|
||||
return self.sm90_or_sm100_epilogue_smem_size(tile_description)
|
||||
|
||||
#
|
||||
# Sm100 epilogue specific
|
||||
#
|
||||
|
||||
def sm100_epilogue_tile(self, tile_description):
|
||||
cta_tile = (
|
||||
tile_description.blackwell_threadblock_shape[0],
|
||||
tile_description.blackwell_threadblock_shape[1],
|
||||
)
|
||||
mma_tile = cta_tile
|
||||
|
||||
if tile_description.is_2sm:
|
||||
cta_tile = (cta_tile[0] // 2, cta_tile[1])
|
||||
|
||||
if tile_description.is_2sm and mma_tile[0] == 128:
|
||||
tmem_warps = (2, 2)
|
||||
else:
|
||||
tmem_warps = (4, 1)
|
||||
|
||||
if self.dag_ir.has_node("C"):
|
||||
element_c = self.dag_ir.get_node_meta("C").element
|
||||
element_c_size = DataTypeSize[element_c]
|
||||
else:
|
||||
element_c = None
|
||||
element_c_size = 0
|
||||
|
||||
element_d = self.dag_ir.get_node_meta("D").element
|
||||
|
||||
DisableSource = (
|
||||
element_c is None
|
||||
or not self.dag_ir.has_node("C")
|
||||
or self.dag_ir.get_node_meta("C").element == DataType.void
|
||||
)
|
||||
|
||||
CtaM = cta_tile[0]
|
||||
CtaN = cta_tile[1]
|
||||
WarpM = tmem_warps[0]
|
||||
WarpN = tmem_warps[1]
|
||||
MaxBits = max(element_c_size, DataTypeSize[element_d])
|
||||
DpFull = 32
|
||||
M = min(CtaM, DpFull * WarpM)
|
||||
|
||||
if DisableSource:
|
||||
# Epilogues w/o residual load are less sensitive to smem allocation
|
||||
# Target a fixed amount of compute per epilogue iteration
|
||||
if MaxBits == 4:
|
||||
# Make epilogue tile larger to reduce the epilogue iterations.
|
||||
# 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same.
|
||||
ComputeElts = 8192
|
||||
Nperf = ComputeElts // M
|
||||
else:
|
||||
ComputeElts = 4096
|
||||
Nperf = ComputeElts // M
|
||||
else:
|
||||
# Epilogues w/ residual load are more sensitive to smem allocation
|
||||
# Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize
|
||||
if MaxBits == 32:
|
||||
Nperf = 16 if CtaM > 64 and CtaN <= 128 else 32
|
||||
elif MaxBits == 16:
|
||||
Nperf = 32 if CtaN <= 128 else 64
|
||||
else:
|
||||
Nperf = 64
|
||||
|
||||
def is_m_major(layout):
|
||||
return flatten(layout.stride[0]) == 1
|
||||
|
||||
if DisableSource or is_m_major(self.dag_ir.get_node_meta("C").tensor.layout):
|
||||
N_min_C = 8 * WarpN
|
||||
elif element_c_size == 6:
|
||||
N_min_C = 128 * WarpN
|
||||
else:
|
||||
N_min_C = (128 // element_c_size) * WarpN
|
||||
|
||||
if is_m_major(self.dag_ir.get_node_meta("D").tensor.layout):
|
||||
N_min_D = 8 * WarpN
|
||||
elif DataTypeSize[element_d] == 6:
|
||||
N_min_D = 128 * WarpN
|
||||
else:
|
||||
N_min_D = (128 // DataTypeSize[element_d]) * WarpN
|
||||
|
||||
N = min(CtaN, max(Nperf, N_min_C, N_min_D))
|
||||
|
||||
tile_m = M
|
||||
tile_n_size = N // WarpN * WarpN
|
||||
|
||||
epilogue_tile_mn = (tile_m, tile_n_size)
|
||||
epi_tiles = product(
|
||||
shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn)
|
||||
)
|
||||
|
||||
stages_d = min(epi_tiles, 2)
|
||||
reuse_smem_c = element_c_size > 8
|
||||
|
||||
if reuse_smem_c:
|
||||
stages_c = max(min(epi_tiles, 4), stages_d + 1)
|
||||
else:
|
||||
stages_c = min(epi_tiles, 4)
|
||||
|
||||
# Record the epilogue tile
|
||||
self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
|
||||
self.epilogue_tile_mn = epilogue_tile_mn
|
||||
self.epi_tiles = epi_tiles
|
||||
self.stages_c = stages_c
|
||||
self.stages_d = stages_d
|
||||
self.reuse_smem_c = reuse_smem_c
|
||||
self.element_c = element_c
|
||||
self.element_d = element_d
|
||||
self.is_source_supported = not DisableSource
|
||||
|
||||
def sm100_epilogue_smem_size(self, tile_description):
|
||||
"""
|
||||
Compute the shared memory size of sm100 collective epilogue
|
||||
"""
|
||||
self.sm100_epilogue_tile(tile_description)
|
||||
return self.sm90_or_sm100_epilogue_smem_size(tile_description)
|
||||
|
||||
def __call__(self, tile_description):
|
||||
return getattr(self, f"sm{self.cc}_epilogue_smem_size")(tile_description)
|
||||
|
||||
#
|
||||
# Helper functions
|
||||
#
|
||||
|
||||
@staticmethod
|
||||
def get_visitor_size(members: list, ebo: bool):
|
||||
"""
|
||||
Get the size of struct in bytes
|
||||
"""
|
||||
offset = 0
|
||||
max_alignment = 1
|
||||
if len(members) > 0:
|
||||
# Get alignment
|
||||
for _, alignment in members:
|
||||
max_alignment = max(max_alignment, alignment)
|
||||
|
||||
for type_size, _ in members:
|
||||
if type_size != 0:
|
||||
offset = (
|
||||
(offset + max_alignment - 1) // max_alignment
|
||||
) * max_alignment
|
||||
if type_size == 0 and not ebo:
|
||||
offset += 1
|
||||
else:
|
||||
offset += type_size
|
||||
offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
|
||||
return (offset, max_alignment)
|
||||
else:
|
||||
# Struct size is at least 1
|
||||
return (1, 1)
|
||||
|
||||
def get_struct_size(self, members: list):
|
||||
"""
|
||||
Get the size of struct in bytes
|
||||
"""
|
||||
return self.get_visitor_size(members, False)
|
||||
|
||||
def get_evt_smem_type(self, node):
|
||||
# Sort the input nodes by edge weight
|
||||
input_types = [
|
||||
self.smem_types[child] for child in self.dag_ir.get_all_inputs(node)
|
||||
]
|
||||
input_types.append(self.smem_types[node])
|
||||
if len(input_types) > 1:
|
||||
ebo = len(input_types) > 4
|
||||
self.smem_types[node] = self.get_visitor_size(input_types, ebo)
|
||||
|
||||
def get_dag_smem_type(self, node):
|
||||
meta = self.dag_ir.get_node_meta(node)
|
||||
subgraph = meta.subgraph
|
||||
subgraph_nodes = subgraph.nodes_topological_order()
|
||||
# Visit the unvisited nodes in subgraph
|
||||
for n in subgraph_nodes:
|
||||
M = subgraph.get_node_meta(n)
|
||||
if M.disabled:
|
||||
continue
|
||||
else:
|
||||
self.smem_types[n] = M.underlying_impl.get_smem_size(
|
||||
self.cta_tile_mnk,
|
||||
self.epilogue_tile_mn,
|
||||
self.stages_c,
|
||||
self.stages_d,
|
||||
self.epi_tiles,
|
||||
)
|
||||
input_types = [self.smem_types[child] for child in subgraph_nodes[:-1]]
|
||||
if len(input_types) > 0:
|
||||
ebo = len(input_types) > 4
|
||||
self.smem_types[node] = self.get_visitor_size(input_types, ebo)
|
||||
46
python/cutlass_api/cutlass_api/fusion/passes/util.py
Normal file
46
python/cutlass_api/cutlass_api/fusion/passes/util.py
Normal file
@@ -0,0 +1,46 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Utilities for passes
|
||||
"""
|
||||
|
||||
# Map from the CC of the kernel to the EVT implementation that the CC targets
|
||||
cc_map = {
|
||||
80: 80,
|
||||
86: 80,
|
||||
89: 80,
|
||||
90: 90,
|
||||
100: 100,
|
||||
101: 100,
|
||||
103: 100,
|
||||
}
|
||||
36
python/cutlass_api/cutlass_api/fusion/pycute/__init__.py
Normal file
36
python/cutlass_api/cutlass_api/fusion/pycute/__init__.py
Normal file
@@ -0,0 +1,36 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from .int_tuple import *
|
||||
from .layout import *
|
||||
from .swizzle import *
|
||||
from .typing import *
|
||||
229
python/cutlass_api/cutlass_api/fusion/pycute/int_tuple.py
Normal file
229
python/cutlass_api/cutlass_api/fusion/pycute/int_tuple.py
Normal file
@@ -0,0 +1,229 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Functions for manipulating IntTuples
|
||||
"""
|
||||
|
||||
from functools import reduce
|
||||
from itertools import chain
|
||||
from typing import Union
|
||||
from .typing import Integer
|
||||
|
||||
|
||||
def is_int(x):
|
||||
return isinstance(x, Integer)
|
||||
|
||||
|
||||
def is_tuple(x):
|
||||
return isinstance(x, tuple)
|
||||
|
||||
|
||||
def flatten(t):
|
||||
if is_tuple(t):
|
||||
if len(t) == 0:
|
||||
return ()
|
||||
else:
|
||||
return tuple(i for a in t for i in flatten(a))
|
||||
else:
|
||||
return (t,)
|
||||
|
||||
|
||||
def signum(a):
|
||||
return bool(a > 0) - bool(a < 0)
|
||||
|
||||
|
||||
def product(a):
|
||||
if is_tuple(a):
|
||||
return reduce(lambda val, elem: val * product(elem), a, 1)
|
||||
else:
|
||||
return a
|
||||
|
||||
|
||||
def inner_product(a, b):
|
||||
if is_tuple(a): # tuple tuple
|
||||
assert len(a) == len(b)
|
||||
return sum(inner_product(x, y) for x, y in zip(a, b))
|
||||
else: # "int" "int"
|
||||
assert not is_tuple(b)
|
||||
return a * b
|
||||
|
||||
|
||||
def tuple_max(a):
|
||||
if is_tuple(a):
|
||||
return max(tuple_max(x) for x in a)
|
||||
else:
|
||||
return a
|
||||
|
||||
|
||||
def elem_scale(a, b):
|
||||
if is_tuple(a):
|
||||
if is_tuple(b): # tuple tuple
|
||||
assert len(a) == len(b)
|
||||
return tuple(elem_scale(x, y) for x, y in zip(a, b))
|
||||
else: # tuple "int"
|
||||
assert False # Error
|
||||
else:
|
||||
if is_tuple(b): # "int" tuple
|
||||
return elem_scale(a, product(b))
|
||||
else: # "int" "int"
|
||||
return a * b
|
||||
|
||||
|
||||
# Inclusive prefix ceil div with output congruent to input a
|
||||
def shape_div(a, b):
|
||||
if is_tuple(a):
|
||||
if is_tuple(b): # tuple tuple
|
||||
assert len(a) == len(b)
|
||||
return tuple(shape_div(x, y) for x, y in zip(a, b))
|
||||
else: # tuple "int"
|
||||
# r = [shape_div(a[0],b)] + [shape_div(a[i],b := shape_div(b, product(a[i-1]))) for i in range(1,len(a))]
|
||||
r = []
|
||||
for v in a:
|
||||
r.append(shape_div(v, b))
|
||||
b = shape_div(b, product(v))
|
||||
return tuple(r)
|
||||
else:
|
||||
if is_tuple(b): # "int" tuple
|
||||
return shape_div(a, product(b))
|
||||
else: # "int" "int"
|
||||
assert a % b == 0 or b % a == 0
|
||||
return (a + b - 1) // b
|
||||
|
||||
|
||||
# Exclusive prefix product with output congruent to input a
|
||||
def prefix_product(a, init=1):
|
||||
if is_tuple(a):
|
||||
if is_tuple(init): # tuple tuple
|
||||
assert len(a) == len(init)
|
||||
return tuple(prefix_product(x, i) for x, i in zip(a, init))
|
||||
else: # tuple "int"
|
||||
# r = [prefix_product(a[0],init)] + [prefix_product(a[i],init := init * product(a[i-1])) for i in range(1,len(a))]
|
||||
r = []
|
||||
for v in a:
|
||||
r.append(prefix_product(v, init))
|
||||
init = init * product(v)
|
||||
return tuple(r)
|
||||
else:
|
||||
if is_tuple(init): # "int" tuple
|
||||
assert False # Error
|
||||
else: # "int" "int"
|
||||
return init
|
||||
|
||||
|
||||
def idx2crd(idx, shape, stride=None):
|
||||
if stride is None:
|
||||
stride = prefix_product(shape)
|
||||
|
||||
if is_tuple(idx):
|
||||
if is_tuple(shape): # tuple tuple tuple
|
||||
assert len(idx) == len(shape) and len(idx) == len(stride)
|
||||
return tuple(idx2crd(i, s, d) for i, s, d in zip(idx, shape, stride))
|
||||
else: # tuple "int" "int"
|
||||
assert False # Error
|
||||
else:
|
||||
if is_tuple(shape): # "int" tuple tuple
|
||||
assert len(shape) == len(stride)
|
||||
return tuple(idx2crd(idx, s, d) for s, d in zip(shape, stride))
|
||||
else: # "int" "int" "int"
|
||||
return (idx // stride) % shape
|
||||
|
||||
|
||||
def crd2idx(crd, shape, stride=None):
|
||||
if stride is None:
|
||||
stride = prefix_product(shape)
|
||||
|
||||
if is_tuple(crd):
|
||||
if is_tuple(shape): # tuple tuple tuple
|
||||
assert len(crd) == len(shape) and len(crd) == len(stride)
|
||||
return sum(crd2idx(c, s, d) for c, s, d in zip(crd, shape, stride))
|
||||
else: # tuple "int" "int"
|
||||
assert False, f"crd={crd}, shape={shape}" # Error
|
||||
else:
|
||||
if crd is None:
|
||||
crd = 0
|
||||
|
||||
if is_tuple(shape): # "int" tuple tuple
|
||||
assert len(shape) == len(stride)
|
||||
result = 0
|
||||
for i in range(len(shape) - 1):
|
||||
result += crd2idx(crd % product(shape[i]), shape[i], stride[i])
|
||||
crd = crd // product(shape[i])
|
||||
return result + crd2idx(crd, shape[-1], stride[-1])
|
||||
else: # "int" "int" "int"
|
||||
return crd * stride
|
||||
|
||||
|
||||
# Transform crd into the dst_shape's iteration space
|
||||
def crd2crd(crd, dst_shape, src_shape=None):
|
||||
if is_tuple(crd):
|
||||
if is_tuple(dst_shape): # tuple tuple
|
||||
assert len(crd) == len(dst_shape)
|
||||
return tuple(crd2crd(x, y) for x, y in zip(crd, dst_shape))
|
||||
else: # tuple "int"
|
||||
# Ambiguous unless we have src_shape
|
||||
assert src_shape is not None
|
||||
return crd2idx(crd, src_shape)
|
||||
else:
|
||||
if is_tuple(dst_shape): # "int" tuple
|
||||
return idx2crd(crd, dst_shape)
|
||||
else: # "int" "int"
|
||||
assert crd < dst_shape
|
||||
return crd
|
||||
|
||||
|
||||
# Filter trg according to crd: keep only elements of trg that are paired with None
|
||||
def slice_(crd: Union[None, tuple, int], trg: Union[tuple, int]):
|
||||
if is_tuple(crd):
|
||||
if is_tuple(trg): # tuple tuple
|
||||
assert len(crd) == len(trg)
|
||||
# match C++ behavior of `filter_tuple` using `tuple_cat(...)`
|
||||
return tuple(
|
||||
chain(
|
||||
*filter(lambda x: x != (), [slice_(c, s) for c, s in zip(crd, trg)])
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert False # tuple "int" : Error
|
||||
elif crd is None:
|
||||
# match C++ behavior `return cute::tuple<B>{b};`
|
||||
return (trg,)
|
||||
else:
|
||||
return ()
|
||||
|
||||
|
||||
# Determine if None appears at any of an int_tuples' terminals
|
||||
def has_none(a: Union[None, tuple, int]):
|
||||
if is_tuple(a):
|
||||
return any(has_none(v) for v in a)
|
||||
else:
|
||||
return a is None
|
||||
410
python/cutlass_api/cutlass_api/fusion/pycute/layout.py
Normal file
410
python/cutlass_api/cutlass_api/fusion/pycute/layout.py
Normal file
@@ -0,0 +1,410 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Definition of CuTe Layouts and functions to manipulate them
|
||||
"""
|
||||
|
||||
from itertools import chain
|
||||
|
||||
from .int_tuple import *
|
||||
|
||||
|
||||
class LayoutBase:
|
||||
pass
|
||||
|
||||
|
||||
def is_layout(x):
|
||||
return isinstance(x, LayoutBase)
|
||||
|
||||
|
||||
class Layout(LayoutBase):
|
||||
def __init__(self, _shape, _stride=None):
|
||||
self.shape = _shape
|
||||
if _stride is None:
|
||||
self.stride = prefix_product(self.shape)
|
||||
else:
|
||||
self.stride = _stride
|
||||
|
||||
# operator ==
|
||||
def __eq__(self, other):
|
||||
return self.shape == other.shape and self.stride == other.stride
|
||||
|
||||
# operator len(L) (len [rank] like tuples)
|
||||
def __len__(self):
|
||||
if is_tuple(self.shape):
|
||||
return len(self.shape)
|
||||
else:
|
||||
return 1
|
||||
|
||||
# operator () (map coord to idx)
|
||||
def __call__(self, *args):
|
||||
"""
|
||||
Map a logical coordinate to a linear index (Coord has no Underscore slice operators)
|
||||
OR
|
||||
Slice the layout and return the sublayout (Coord has an Underscore slice op)
|
||||
|
||||
Follow the same behavior of `Layout::operator(Coord const&)` in cute C++
|
||||
"""
|
||||
if has_none(args):
|
||||
if len(args) == 1:
|
||||
return Layout(slice_(args[0], self.shape), slice_(args[0], self.stride))
|
||||
else:
|
||||
return Layout(slice_(args, self.shape), slice_(args, self.stride))
|
||||
else:
|
||||
if len(args) == 1:
|
||||
return crd2idx(args[0], self.shape, self.stride)
|
||||
else:
|
||||
return crd2idx(args, self.shape, self.stride)
|
||||
|
||||
# operator [] (get-i like tuples)
|
||||
def __getitem__(self, i):
|
||||
if is_tuple(self.shape):
|
||||
return Layout(self.shape[i], self.stride[i])
|
||||
else:
|
||||
assert i == 0
|
||||
return Layout(self.shape, self.stride)
|
||||
|
||||
# size(layout) Size of the domain
|
||||
def size(self):
|
||||
return product(self.shape)
|
||||
|
||||
# cosize(layout) Size of the codomain
|
||||
def cosize(self):
|
||||
return self(self.size() - 1) + 1
|
||||
|
||||
# print and str
|
||||
def __str__(self):
|
||||
return f"{self.shape}:{self.stride}"
|
||||
|
||||
# error msgs and representation
|
||||
def __repr__(self):
|
||||
return f"Layout({self.shape},{self.stride})"
|
||||
|
||||
|
||||
# Make Layout from a list of layouts (each layout it's own mode in the result)
|
||||
def make_layout(*layouts):
|
||||
if len(layouts) == 1 and not is_layout(layouts[0]):
|
||||
layouts = layouts[0]
|
||||
|
||||
shape, stride = zip(*((a.shape, a.stride) for a in layouts))
|
||||
return Layout(shape, stride)
|
||||
|
||||
|
||||
# Size of the domain
|
||||
def size(layout):
|
||||
if is_layout(layout):
|
||||
return layout.size()
|
||||
return product(layout)
|
||||
|
||||
|
||||
# Size of the codomain
|
||||
def cosize(layout):
|
||||
return layout.cosize()
|
||||
|
||||
|
||||
# Layout coalesce -- flatten and combine as many modes as possible while preserving the int-to-int function
|
||||
def coalesce(layout, profile=None):
|
||||
if is_tuple(profile):
|
||||
assert len(layout) >= len(profile)
|
||||
return make_layout(
|
||||
chain(
|
||||
(coalesce(layout[i], profile[i]) for i in range(0, len(profile))),
|
||||
(layout[i] for i in range(len(profile), len(layout))),
|
||||
)
|
||||
)
|
||||
|
||||
result_shape = [1]
|
||||
result_stride = [0]
|
||||
for shape, stride in zip(flatten(layout.shape), flatten(layout.stride)):
|
||||
# skip their shape-1s
|
||||
if shape == 1:
|
||||
continue
|
||||
# replace our shape-1 with anything
|
||||
elif result_shape[-1] == 1:
|
||||
result_shape[-1] = shape
|
||||
result_stride[-1] = stride
|
||||
# merge modes if the shape*stride match
|
||||
elif result_shape[-1] * result_stride[-1] == stride:
|
||||
result_shape[-1] = result_shape[-1] * shape
|
||||
# append a new mode
|
||||
else:
|
||||
result_shape.append(shape)
|
||||
result_stride.append(stride)
|
||||
|
||||
if len(result_shape) == 1:
|
||||
return Layout(result_shape[0], result_stride[0])
|
||||
else:
|
||||
return Layout(tuple(result_shape), tuple(result_stride))
|
||||
|
||||
|
||||
# Layout filter -- replace all stride-0 modes with size-1 and then coalesce to remove them
|
||||
def filter(layout, profile=None):
|
||||
if is_tuple(profile):
|
||||
assert len(layout) >= len(profile)
|
||||
return make_layout(
|
||||
chain(
|
||||
(filter(layout[i], profile[i]) for i in range(0, len(profile))),
|
||||
(layout[i] for i in range(len(profile), len(layout))),
|
||||
)
|
||||
)
|
||||
|
||||
result_shape = []
|
||||
result_stride = []
|
||||
for shape, stride in zip(flatten(layout.shape), flatten(layout.stride)):
|
||||
# skip their shape-1s and stride-0s
|
||||
if not (shape == 1 or stride == 0):
|
||||
result_shape.append(shape)
|
||||
result_stride.append(stride)
|
||||
|
||||
if len(result_shape) == 0:
|
||||
return Layout(1, 0)
|
||||
else:
|
||||
return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
|
||||
|
||||
|
||||
# Layout composition
|
||||
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
|
||||
def composition(layoutA, layoutB):
|
||||
if layoutB is None:
|
||||
return layoutA
|
||||
elif is_int(layoutB):
|
||||
return composition(layoutA, Layout(layoutB))
|
||||
elif is_tuple(layoutB):
|
||||
assert len(layoutA) >= len(layoutB)
|
||||
return make_layout(
|
||||
chain(
|
||||
(composition(layoutA[i], layoutB[i]) for i in range(0, len(layoutB))),
|
||||
(layoutA[i] for i in range(len(layoutB), len(layoutA))),
|
||||
)
|
||||
)
|
||||
elif is_tuple(layoutB.shape):
|
||||
return make_layout(composition(layoutA, layoutB_i) for layoutB_i in layoutB)
|
||||
|
||||
if layoutB.stride == 0:
|
||||
return Layout(layoutB.shape, 0)
|
||||
else:
|
||||
result_shape = []
|
||||
result_stride = []
|
||||
rest_shape = layoutB.shape
|
||||
rest_stride = layoutB.stride
|
||||
flat_A = coalesce(layoutA)
|
||||
for curr_shape, curr_stride in zip(
|
||||
flatten(flat_A.shape)[:-1], flatten(flat_A.stride)[:-1]
|
||||
):
|
||||
assert curr_shape % rest_stride == 0 or rest_stride % curr_shape == 0
|
||||
new_shape = min(max(1, curr_shape // rest_stride), rest_shape)
|
||||
|
||||
if new_shape != 1:
|
||||
result_shape.append(new_shape)
|
||||
result_stride.append(rest_stride * curr_stride)
|
||||
|
||||
rest_shape = rest_shape // new_shape
|
||||
rest_stride = -(
|
||||
-rest_stride // curr_shape
|
||||
) # Python exclusive impl: "//" is always floor div so == ceil_div(abs(rest_stride), curr_shape) * signum(rest_stride)
|
||||
|
||||
if rest_shape != 1 or len(result_shape) == 0:
|
||||
result_shape.append(rest_shape)
|
||||
result_stride.append(rest_stride * flatten(flat_A.stride)[-1])
|
||||
|
||||
if len(result_shape) == 1:
|
||||
return Layout(result_shape[0], result_stride[0])
|
||||
else:
|
||||
return Layout(tuple(result_shape), tuple(result_stride))
|
||||
|
||||
|
||||
# Layout complement
|
||||
def complement(layout, max_idx=1):
|
||||
if is_int(layout):
|
||||
return complement(Layout(layout))
|
||||
|
||||
result_shape = []
|
||||
result_stride = []
|
||||
current_idx = 1
|
||||
|
||||
sorted_DS = sorted(zip(flatten(layout.stride), flatten(layout.shape)))
|
||||
for stride, shape in sorted_DS:
|
||||
if stride == 0 or shape == 1:
|
||||
continue
|
||||
|
||||
in_bound = current_idx <= shape * stride
|
||||
# To support symbolic value which can't be evaluated now
|
||||
assert (type(in_bound) is not bool) or in_bound
|
||||
|
||||
result_shape.append(stride // current_idx)
|
||||
result_stride.append(current_idx)
|
||||
current_idx = shape * stride
|
||||
|
||||
result_shape.append((max_idx + current_idx - 1) // current_idx) # ceil_div
|
||||
result_stride.append(current_idx)
|
||||
|
||||
return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
|
||||
|
||||
|
||||
# Layout right inverse
|
||||
def right_inverse(layout):
|
||||
if layout is None:
|
||||
return None
|
||||
elif is_int(layout):
|
||||
return Layout(layout)
|
||||
|
||||
result_shape = []
|
||||
result_stride = []
|
||||
current_idx = 1
|
||||
|
||||
flat_shape = flatten(layout.shape)
|
||||
flat_stride = flatten(layout.stride)
|
||||
sorted_DSA = sorted(zip(flat_stride, flat_shape, prefix_product(flat_shape)))
|
||||
for stride, shape, rstride in sorted_DSA:
|
||||
if shape == 1:
|
||||
continue
|
||||
if current_idx != stride:
|
||||
break
|
||||
|
||||
result_shape.append(shape)
|
||||
result_stride.append(rstride)
|
||||
current_idx = shape * stride
|
||||
|
||||
return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
|
||||
|
||||
|
||||
# Layout left inverse
|
||||
def left_inverse(layout):
|
||||
if layout is None:
|
||||
return None
|
||||
elif is_int(layout):
|
||||
return Layout(layout)
|
||||
return right_inverse(make_layout(layout, complement(layout)))
|
||||
|
||||
|
||||
# Split a layout by the composition of B and the "rest"
|
||||
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
|
||||
def logical_divide(layoutA, layoutB):
|
||||
if layoutB is None:
|
||||
return layoutA
|
||||
elif is_int(layoutB):
|
||||
return logical_divide(layoutA, Layout(layoutB))
|
||||
elif is_tuple(layoutB):
|
||||
assert len(layoutA) >= len(layoutB)
|
||||
return make_layout(
|
||||
chain(
|
||||
(
|
||||
logical_divide(layoutA[i], layoutB[i])
|
||||
for i in range(0, len(layoutB))
|
||||
),
|
||||
(layoutA[i] for i in range(len(layoutB), len(layoutA))),
|
||||
)
|
||||
)
|
||||
|
||||
return composition(
|
||||
layoutA, make_layout(layoutB, complement(layoutB, size(layoutA)))
|
||||
)
|
||||
|
||||
|
||||
# Reproduce a layoutA over a layoutB
|
||||
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
|
||||
def logical_product(layoutA, layoutB):
|
||||
if layoutB is None:
|
||||
return layoutA
|
||||
elif is_int(layoutB):
|
||||
return logical_divide(layoutA, Layout(layoutB))
|
||||
elif is_tuple(layoutB):
|
||||
assert len(layoutA) >= len(layoutB)
|
||||
return make_layout(
|
||||
chain(
|
||||
(
|
||||
logical_product(layoutA[i], layoutB[i])
|
||||
for i in range(0, len(layoutB))
|
||||
),
|
||||
(layoutA[i] for i in range(len(layoutB), len(layoutA))),
|
||||
)
|
||||
)
|
||||
|
||||
return make_layout(
|
||||
layoutA,
|
||||
composition(complement(layoutA, size(layoutA) * cosize(layoutB)), layoutB),
|
||||
)
|
||||
|
||||
|
||||
# Gather the modes from a hierarchical logical_divide or logical_product
|
||||
def hier_unzip(splitter, layoutA, layoutB):
|
||||
if layoutB is None:
|
||||
return make_layout(Layout(1, 0), layoutA)
|
||||
elif is_tuple(layoutB):
|
||||
assert len(layoutA) >= len(layoutB)
|
||||
# A layout with shape ((A,a),(B,b),(C,c))
|
||||
split = make_layout(
|
||||
hier_unzip(splitter, layoutA[i], layoutB[i]) for i in range(0, len(layoutB))
|
||||
)
|
||||
# Gather to shape ((A,B,C,...),(a,b,c,...,y,z))
|
||||
return make_layout(
|
||||
make_layout(split[i][0] for i in range(0, len(layoutB))),
|
||||
make_layout(
|
||||
chain(
|
||||
(split[i][1] for i in range(0, len(layoutB))),
|
||||
(layoutA[i] for i in range(len(layoutB), len(layoutA))),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
# splitter must return a rank-2 layout
|
||||
return splitter(layoutA, layoutB)
|
||||
|
||||
|
||||
# Apply logical divide hierarchically and gather the split modes into two modes
|
||||
def zipped_divide(layoutA, layoutB):
|
||||
return hier_unzip(logical_divide, layoutA, layoutB)
|
||||
|
||||
|
||||
# Perform logical divide hierarchically and gather tiles (B-layouts) into a new mode
|
||||
def tiled_divide(layoutA, layoutB):
|
||||
result = zipped_divide(layoutA, layoutB)
|
||||
return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))])
|
||||
|
||||
|
||||
# Apply logical product hierarchically and gather the split modes into two modes
|
||||
def zipped_product(layoutA, layoutB):
|
||||
return hier_unzip(logical_product, layoutA, layoutB)
|
||||
|
||||
|
||||
# Perform logical product hierarchically and gather tiles (B-layouts) into a new mode
|
||||
def tiled_product(layoutA, layoutB):
|
||||
result = zipped_product(layoutA, layoutB)
|
||||
return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))])
|
||||
|
||||
|
||||
def slice_and_offset(crd: tuple, layout: Layout):
|
||||
return (
|
||||
Layout(slice_(crd, layout.shape), slice_(crd, layout.stride)),
|
||||
crd2idx(crd, layout.shape, layout.stride),
|
||||
)
|
||||
133
python/cutlass_api/cutlass_api/fusion/pycute/swizzle.py
Normal file
133
python/cutlass_api/cutlass_api/fusion/pycute/swizzle.py
Normal file
@@ -0,0 +1,133 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Methods for layout swizzling
|
||||
"""
|
||||
|
||||
from .layout import *
|
||||
|
||||
|
||||
def shiftr(a, s):
|
||||
return a >> s if s > 0 else shiftl(a, -s)
|
||||
|
||||
|
||||
def shiftl(a, s):
|
||||
return a << s if s > 0 else shiftr(a, -s)
|
||||
|
||||
|
||||
## A generic Swizzle functor
|
||||
# 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx
|
||||
# ^--^ Base is the number of least-sig bits to keep constant
|
||||
# ^-^ ^-^ Bits is the number of bits in the mask
|
||||
# ^---------^ Shift is the distance to shift the YYY mask
|
||||
# (pos shifts YYY to the right, neg shifts YYY to the left)
|
||||
#
|
||||
# e.g. Given
|
||||
# 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx
|
||||
# the result is
|
||||
# 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY
|
||||
#
|
||||
class Swizzle:
|
||||
def __init__(self, bits, base, shift):
|
||||
assert bits >= 0
|
||||
assert base >= 0
|
||||
assert abs(shift) >= bits
|
||||
self.bits = bits
|
||||
self.base = base
|
||||
self.shift = shift
|
||||
bit_msk = (1 << bits) - 1
|
||||
self.yyy_msk = bit_msk << (base + max(0, shift))
|
||||
self.zzz_msk = bit_msk << (base - min(0, shift))
|
||||
|
||||
# operator () (transform integer)
|
||||
def __call__(self, offset):
|
||||
return offset ^ shiftr(offset & self.yyy_msk, self.shift)
|
||||
|
||||
# Size of the domain
|
||||
def size(self):
|
||||
return 1 << (self.bits + self.base + abs(self.shift))
|
||||
|
||||
# Size of the codomain
|
||||
def cosize(self):
|
||||
return self.size()
|
||||
|
||||
# print and str
|
||||
def __str__(self):
|
||||
return f"SW_{self.bits}_{self.base}_{self.shift}"
|
||||
|
||||
# error msgs and representation
|
||||
def __repr__(self):
|
||||
return f"Swizzle({self.bits},{self.base},{self.shift})"
|
||||
|
||||
|
||||
class ComposedLayout(LayoutBase):
|
||||
def __init__(self, layoutB, offset, layoutA):
|
||||
self.layoutB = layoutB
|
||||
self.offset = offset
|
||||
self.layoutA = layoutA
|
||||
|
||||
# operator ==
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
self.layoutB == other.layoutB
|
||||
and self.offset == other.offset
|
||||
and self.layoutA == other.layoutA
|
||||
)
|
||||
|
||||
# operator len(L) (len [rank] like tuples)
|
||||
def __len__(self):
|
||||
return len(self.layoutA)
|
||||
|
||||
# operator () (map coord to idx)
|
||||
def __call__(self, *args):
|
||||
return self.layoutB(self.offset + self.layoutA(*args))
|
||||
|
||||
# operator [] (get-i like tuples)
|
||||
def __getitem__(self, i):
|
||||
return ComposedLayout(self.layoutB, self.offset, self.layoutA[i])
|
||||
|
||||
# size(layout) Size of the domain
|
||||
def size(self):
|
||||
return size(self.layoutA)
|
||||
|
||||
# cosize(layout) Size of the codomain
|
||||
def cosize(self):
|
||||
return cosize(self.layoutB)
|
||||
|
||||
# print and str
|
||||
def __str__(self):
|
||||
return f"{self.layoutB} o {self.offset} o {self.layoutA}"
|
||||
|
||||
# error msgs and representation
|
||||
def __repr__(self):
|
||||
return f"ComposedLayout({repr(self.layoutB)},{repr(self.offset)},{repr(self.layoutA)})"
|
||||
42
python/cutlass_api/cutlass_api/fusion/pycute/typing.py
Normal file
42
python/cutlass_api/cutlass_api/fusion/pycute/typing.py
Normal file
@@ -0,0 +1,42 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from abc import ABC
|
||||
|
||||
|
||||
class Integer(ABC):
|
||||
@classmethod
|
||||
def __subclasshook__(cls, c):
|
||||
if c in [bool, float]:
|
||||
return False
|
||||
|
||||
return issubclass(c, int)
|
||||
196
python/cutlass_api/cutlass_api/kernel.py
Normal file
196
python/cutlass_api/cutlass_api/kernel.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from typing import final
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
from cutlass_api.arguments import EpilogueArguments, RuntimeArguments
|
||||
from cutlass_api.artifact import CompiledArtifact
|
||||
from cutlass_api.metadata import KernelMetadata
|
||||
from cutlass_api.status import Status
|
||||
|
||||
|
||||
class Kernel(ABC):
|
||||
"""
|
||||
Base class for all kernels to be implemented in providers
|
||||
"""
|
||||
|
||||
@final
|
||||
def supports(self, args: RuntimeArguments) -> Status:
|
||||
"""
|
||||
Returns whether the kernel can be compiled or run with the provided arguments
|
||||
|
||||
:param args: arguments with which the kernel is to be compiled or run
|
||||
:type args: RuntimeArguments
|
||||
|
||||
:return: Status indicating whether the kernel can be compiled or run with the provided arguments
|
||||
:rtype: Status
|
||||
"""
|
||||
# Metadata should capture most common checks
|
||||
if not (status := self.metadata.supports(args)):
|
||||
return status
|
||||
# Additional checks can be implemented in the subclass
|
||||
return self._supports(args)
|
||||
|
||||
@abstractmethod
|
||||
def compile(
|
||||
self, args: RuntimeArguments, cc: int | None = None
|
||||
) -> CompiledArtifact:
|
||||
"""
|
||||
Compiles the kernel.
|
||||
|
||||
For just-in-time compilation, the CompiledArtifact can be used to execute the kernel using run().
|
||||
For ahead-of-time compilation, the CompiledArtifact can be saved and restored later to use with run().
|
||||
:param args: Arguments to compile the kernel with. These need not be the same arguments as those passed to the run() method.
|
||||
:type args: RuntimeArguments
|
||||
:param cc: Compute capability of device for which the kernel is to be compiled. For example, if running on H100, this should be set to 90.
|
||||
:type cc: int
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@final
|
||||
def run(
|
||||
self,
|
||||
args: RuntimeArguments,
|
||||
compiled_artifact: CompiledArtifact | None = None,
|
||||
stream=None,
|
||||
workspace=None,
|
||||
assume_supported_args: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Executes the kernel end-to-end with provided arguments -- check args are supported, compile if needed, create a default stream if needed, and launch the kernel.
|
||||
|
||||
:param args: Arguments to run the kernel with
|
||||
:type args: RuntimeArguments
|
||||
:param compiled_artifact: Compiled kernel object returned from the compile() method above. If None, the kernel will first be compiled.
|
||||
:type compiled_artifact: CompiledArtifact | None
|
||||
:param stream: Stream to execute the kernel on. If not provided, the default/null stream cuda.CUstream(0) is used.
|
||||
:type stream: cuda.CUstream, torch.cuda.Stream, or other stream-like object, or None
|
||||
:param workspace: Allocation of workspace at least as large as the workspace size returned from the get_workspace_size() method. If the kernel does not require workspace, this can be None.
|
||||
If a workspace of inappropriate size is provided, the behavior is undefined and the kernel may crash.
|
||||
:type workspace: any | None
|
||||
:param assume_supported_args: By default, kernel.supports(args) is called to check if the arguments are supported. If True, this check is skipped.
|
||||
:type assume_supported_args: bool
|
||||
"""
|
||||
if not assume_supported_args and not (supports := self.supports(args)):
|
||||
raise ValueError(
|
||||
f"Kernel does not support the provided arguments: {supports.error}"
|
||||
)
|
||||
|
||||
compiled_artifact = (
|
||||
self.compile(args) if not compiled_artifact else compiled_artifact
|
||||
)
|
||||
if not stream:
|
||||
stream = cuda.CUstream(0)
|
||||
return self._run(args, compiled_artifact, stream, workspace)
|
||||
|
||||
def get_workspace_size(self, args: RuntimeArguments) -> int:
|
||||
"""
|
||||
Returns the size of the workspace required by the kernel in bytes.
|
||||
|
||||
:param args: arguments of the kernel
|
||||
:type args: RuntimeArguments
|
||||
|
||||
:return: size of the workspace required by the kernel in bytes
|
||||
:rtype: int
|
||||
"""
|
||||
return 0
|
||||
|
||||
def initialize_workspace(self, args: RuntimeArguments, workspace) -> None:
|
||||
"""
|
||||
Initializes the workspace for the kernel.
|
||||
:param args: Arguments to initialize the workspace for
|
||||
:type args: RuntimeArguments
|
||||
:param workspace: Workspace to initialize
|
||||
:type workspace: any
|
||||
"""
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def generate_kernels(
|
||||
metadata_filter: Callable[[KernelMetadata], bool],
|
||||
epilogue_args: EpilogueArguments = None,
|
||||
cc: int = None,
|
||||
) -> list["Kernel"]:
|
||||
"""
|
||||
Populates the `kernels` list with all supported kernel configurations
|
||||
for the given compute capability and arguments.
|
||||
:param metadata_filter: Optional function that takes KernelMetadata and returns True for kernels to include
|
||||
:type metadata_filter: Callable[[KernelMetadata], bool]
|
||||
:param epilogue_args: Optional arguments to pass to kernel epilogue
|
||||
:type epilogue_args: EpilogueArguments | None
|
||||
:param cc: Optional compute capability to target; e.g., 90 for H100
|
||||
:type cc: int | None
|
||||
|
||||
:return: list of all supported kernel configurations
|
||||
:rtype: list[Kernel]
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _supports(self, args: RuntimeArguments) -> Status:
|
||||
"""
|
||||
Classes may override this method to perform any additional checks that are not captured in metadata.
|
||||
Ideally, all/most checks should be captured in the metadata.
|
||||
By default, no such checks are performed and the method trivially returns Status.success().
|
||||
|
||||
:param args: Arguments to check support for
|
||||
:type args: RuntimeArguments
|
||||
|
||||
:return: Status indicating success, or reason why the kernel does not support the provided arguments
|
||||
:rtype: Status
|
||||
"""
|
||||
return Status.success()
|
||||
|
||||
@abstractmethod
|
||||
def _run(
|
||||
self,
|
||||
args: RuntimeArguments,
|
||||
compiled_artifact: CompiledArtifact,
|
||||
stream,
|
||||
workspace=None,
|
||||
) -> None:
|
||||
"""
|
||||
A miniminal version of the run() method that assumes args are supported, a valid compiled artifact and a valid stream are provided.
|
||||
It is intended to be overridden by subclasses to implement the actual kernel execution, and be a minimal wrapper to launching the kernel.
|
||||
|
||||
:param args: Arguments to run the kernel with. It is assumed that kernel.supports(args) is True.
|
||||
:type args: RuntimeArguments
|
||||
:param compiled_artifact: Compiled kernel object returned from the compile() method
|
||||
:type compiled_artifact: CompiledArtifact
|
||||
:param stream: Stream to execute the kernel on.
|
||||
:type stream: cuda.CUstream, torch.cuda.Stream, or other stream-like object
|
||||
:param workspace: Allocation of workspace at least as large as the workspace size returned from the get_workspace_size() method.
|
||||
If the kernel does not require workspace, this can be None.
|
||||
If a workspace of inappropriate size is provided, the behavior is undefined and the kernel may crash.
|
||||
:type workspace: any
|
||||
"""
|
||||
raise NotImplementedError
|
||||
138
python/cutlass_api/cutlass_api/manifest.py
Normal file
138
python/cutlass_api/cutlass_api/manifest.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
from cutlass_api.arguments import RuntimeArguments
|
||||
from cutlass_api.kernel import Kernel
|
||||
from cutlass_api.metadata import KernelMetadata
|
||||
from cutlass_api.providers import available_providers
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Manifest:
|
||||
def __init__(self) -> None:
|
||||
self._candidate_kernels = []
|
||||
|
||||
@staticmethod
|
||||
def get_kernels(
|
||||
args: RuntimeArguments = None,
|
||||
metadata_filter: Callable[[KernelMetadata], bool] | None = None,
|
||||
cc: int = None,
|
||||
providers: list[str] = None,
|
||||
) -> list[Kernel]:
|
||||
"""
|
||||
Get the kernels that match the given arguments, metadata filter, and compute capability.
|
||||
|
||||
:param args: the arguments of the kernel
|
||||
:type args: RuntimeArguments
|
||||
:param metadata_filter: a boolean function that takes in KernelMetadata and returns whether
|
||||
a Kernel from this metadata should be included
|
||||
:type metadata_filter: Callable[[KernelMetadata], bool]
|
||||
:param cc: the compute capability
|
||||
:type cc: int
|
||||
:param providers: the providers to use
|
||||
:type providers: list[str]
|
||||
|
||||
:return: the kernels that match the given arguments, metadata filter, and compute capability
|
||||
:rtype: list[Kernel]
|
||||
"""
|
||||
# Setup providers to use
|
||||
providers_to_use = []
|
||||
if providers is None:
|
||||
providers_to_use = list(available_providers.values())
|
||||
else:
|
||||
for provider in providers:
|
||||
if provider not in available_providers:
|
||||
raise ValueError(f"Provider {provider} is not available")
|
||||
providers_to_use.append(available_providers[provider])
|
||||
|
||||
# Setup filter function
|
||||
filter_fn = (lambda x: True) if metadata_filter is None else metadata_filter
|
||||
if args is None:
|
||||
full_filter_fn = filter_fn
|
||||
else:
|
||||
|
||||
def full_filter_fn(metadata: KernelMetadata) -> bool:
|
||||
if not filter_fn(metadata):
|
||||
return False
|
||||
if not (supports := metadata.supports(args)):
|
||||
_logger.debug(
|
||||
f"Rejecting kernel {metadata.kernel_name}. Reason: {supports.error}"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
epilogue_args = None if args is None else args.epilogue
|
||||
kernels = [
|
||||
k
|
||||
# Generate kernels from all providers
|
||||
for provider in providers_to_use
|
||||
# Filter kernels by metadata
|
||||
for k in provider.generate_kernels(
|
||||
full_filter_fn, cc=cc, epilogue_args=epilogue_args
|
||||
)
|
||||
# Do any additional checks on args that are not captured by metadata
|
||||
if not args or k._supports(args)
|
||||
]
|
||||
return kernels
|
||||
|
||||
def add_kernels(
|
||||
self,
|
||||
args: RuntimeArguments = None,
|
||||
metadata_filter: Callable[[KernelMetadata], bool] | None = None,
|
||||
cc: int = None,
|
||||
providers: list[str] = None,
|
||||
) -> list[Kernel]:
|
||||
"""
|
||||
Get the kernels that match the given arguments, metadata filter, and compute capability,
|
||||
and add them to the Manifest's set of discovered kernels.
|
||||
|
||||
:param args: the arguments of the kernel
|
||||
:type args: RuntimeArguments
|
||||
:param metadata_filter: a boolean function that takes in KernelMetadata and returns whether
|
||||
a Kernel from this metadata should be included
|
||||
:type metadata_filter: Callable[[KernelMetadata], bool]
|
||||
:param cc: the compute capability
|
||||
:type cc: int
|
||||
:param providers: the providers to use
|
||||
:type providers: list[str]
|
||||
|
||||
:return: the kernels that match the given arguments, metadata filter, and compute capability
|
||||
:rtype: list[Kernel]
|
||||
"""
|
||||
matched_kernels = Manifest.get_kernels(args, metadata_filter, cc, providers)
|
||||
self._candidate_kernels.extend(matched_kernels)
|
||||
return matched_kernels
|
||||
|
||||
@property
|
||||
def kernels(self) -> list[Kernel]:
|
||||
return self._candidate_kernels
|
||||
504
python/cutlass_api/cutlass_api/metadata.py
Normal file
504
python/cutlass_api/cutlass_api/metadata.py
Normal file
@@ -0,0 +1,504 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import cutlass
|
||||
|
||||
import cutlass.cute as cute
|
||||
|
||||
from cutlass_api.arguments import (
|
||||
ElementwiseArguments,
|
||||
EpilogueArguments,
|
||||
GemmArguments,
|
||||
RuntimeArguments,
|
||||
)
|
||||
from cutlass_api.status import Status
|
||||
from cutlass_api.utils import TensorWrapper
|
||||
|
||||
|
||||
def _convert_stride(shape: tuple[int, ...], stride: tuple[int, ...]) -> tuple[int, ...]:
|
||||
"""
|
||||
Zeros out modes of stride that have a shape of 1.
|
||||
|
||||
:param shape: the shape of the tensor
|
||||
:type shape: tuple[int, ...]
|
||||
:param stride: the stride of the tensor
|
||||
:type stride: tuple[int, ...]
|
||||
|
||||
:return: the converted stride
|
||||
:rtype: tuple[int, ...]
|
||||
"""
|
||||
new_stride = []
|
||||
for i in range(len(shape)):
|
||||
if shape[i] == 1:
|
||||
new_stride.append(0)
|
||||
else:
|
||||
new_stride.append(stride[i])
|
||||
return new_stride
|
||||
|
||||
|
||||
def _get_max_pow2_alignment(
|
||||
shape: tuple[int, ...], stride: tuple[int, ...], dtype: cutlass.Numeric
|
||||
) -> int:
|
||||
"""
|
||||
Get the maximum power of 2 alignment for a given data type
|
||||
|
||||
:param shape: the shape of the tensor
|
||||
:type shape: tuple[int, ...]
|
||||
:param stride: the stride of the tensor
|
||||
:type stride: tuple[int, ...]
|
||||
:param dtype: the data type
|
||||
:type dtype: cutlass.Numeric
|
||||
|
||||
:return: the maximum power of 2 alignment
|
||||
:rtype: int
|
||||
"""
|
||||
if 1 not in stride:
|
||||
return 1
|
||||
|
||||
major_mode_idx = stride.index(1)
|
||||
num_major_elements = shape[major_mode_idx]
|
||||
for alignment in [128, 64, 32, 16, 8, 4, 2]:
|
||||
num_contiguous_elements = alignment * 8 // dtype.width
|
||||
if num_major_elements % num_contiguous_elements == 0:
|
||||
return alignment
|
||||
return 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorAttributes:
|
||||
"""
|
||||
Description of a single tensor. This includes the data type, stride, and alignment.
|
||||
|
||||
:param dtype: The data type of the tensor.
|
||||
:type dtype: cutlass.Numeric
|
||||
:param stride: The stride of the tensor.
|
||||
:type stride: tuple[int, ...]
|
||||
:param alignment: The alignment of the tensor
|
||||
:type alignment: int
|
||||
"""
|
||||
|
||||
dtype: cutlass.Numeric # F32, F16, etc.
|
||||
stride: tuple[int, ...]
|
||||
alignment: int
|
||||
|
||||
def supports(self, operand: TensorWrapper | Self) -> Status:
|
||||
"""
|
||||
Checks whether the provided args satisfy the properties described by
|
||||
these TensorAttributes.
|
||||
|
||||
:param operand: The operand to check support for.
|
||||
:type operand: TensorWrapper | Self
|
||||
|
||||
:return: Whether the provided operand satisfies the properties described by
|
||||
these TensorAttributes.
|
||||
:rtype: Status
|
||||
"""
|
||||
if isinstance(operand, TensorWrapper):
|
||||
if operand.element_type != self.dtype:
|
||||
return Status.fail(
|
||||
f"Expected element type {self.dtype}, got {operand.element_type}"
|
||||
)
|
||||
elif operand.dtype != self.dtype:
|
||||
return Status.fail(
|
||||
f"Expected element type {self.dtype}, got {operand.dtype}"
|
||||
)
|
||||
|
||||
# Normalize the operand stride to have 0's where operand.shape is 1
|
||||
normalized_operand_stride = _convert_stride(operand.shape, operand.stride)
|
||||
|
||||
# If the operand is batched, the expected stride is the same as the self.stride
|
||||
if len(self.stride) == len(normalized_operand_stride):
|
||||
expected_stride = self.stride
|
||||
# We can try if strides match without the batch mode
|
||||
elif len(self.stride) - 1 == len(normalized_operand_stride):
|
||||
expected_stride = self.stride[1:]
|
||||
else:
|
||||
return Status.fail(
|
||||
f"Expected tensor with strides of rank {len(self.stride)} (batched) or {len(self.stride) - 1} (unbatched), got {len(normalized_operand_stride)} ({normalized_operand_stride})"
|
||||
)
|
||||
|
||||
# Strides are considered compatible if:
|
||||
# 1. They have the same rank (checked above)
|
||||
# 2. The leading mode is the same (or both are all 0)
|
||||
all_zeros = all(x == 0 for x in expected_stride) and all(
|
||||
x == 0 for x in normalized_operand_stride
|
||||
)
|
||||
|
||||
# When setting stride from args, any modes of stride 1 and shape 1
|
||||
# are changed to have stride 0. Thus, there will only be one mode
|
||||
# with stride 1.
|
||||
if not all_zeros and normalized_operand_stride[expected_stride.index(1)] != 1:
|
||||
return Status.fail(
|
||||
f"Expected stride[{expected_stride.index(1)}] to be 1, got {normalized_operand_stride[expected_stride.index(1)]} (strides: {normalized_operand_stride})"
|
||||
)
|
||||
|
||||
# Alignment of operand should be divisible by this metadata's alignment
|
||||
if isinstance(operand, TensorWrapper):
|
||||
operand_alignment = _get_max_pow2_alignment(
|
||||
operand.shape, normalized_operand_stride, operand.element_type
|
||||
)
|
||||
else:
|
||||
operand_alignment = operand.alignment
|
||||
|
||||
if operand_alignment % self.alignment != 0:
|
||||
return Status.fail(
|
||||
f"Expected operand alignment {operand_alignment} (strides: {normalized_operand_stride}) to be a multiple of {self.alignment}"
|
||||
)
|
||||
|
||||
return Status.success()
|
||||
|
||||
@staticmethod
|
||||
def from_tensor(tensor) -> Self:
|
||||
"""
|
||||
Create a TensorAttributes from a tensor.
|
||||
|
||||
:param tensor: The tensor to create a TensorAttributes from.
|
||||
:type tensor: cute.Tensor
|
||||
|
||||
:return: The TensorAttributes corresponding to the provided tensor.
|
||||
:rtype: TensorAttributes
|
||||
"""
|
||||
stride = _convert_stride(tensor.shape, tensor.stride)
|
||||
max_alignment = _get_max_pow2_alignment(
|
||||
tensor.shape, stride, tensor.element_type
|
||||
)
|
||||
return TensorAttributes(
|
||||
dtype=tensor.element_type, stride=stride, alignment=max_alignment
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OperandsMetadata(ABC):
|
||||
"""
|
||||
Base metadata class for descriptions of operands (e.g., GEMM A, B, out).
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def supports(self, args: RuntimeArguments) -> Status:
|
||||
"""
|
||||
Checks whether the provided args satisfy the properties described by
|
||||
the operands in this metadata.
|
||||
|
||||
:param args: The arguments to check support for.
|
||||
:type args: RuntimeArguments
|
||||
|
||||
:return: Whether the provided args satisfy the properties described by
|
||||
the operands in this metadata.
|
||||
:rtype: Status
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DesignMetadata(ABC):
|
||||
"""
|
||||
Base metadata class for descriptions of design parameters for an operation
|
||||
(e.g., tile shape, cluster shape, etc.).
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def supports(self, args: RuntimeArguments) -> Status:
|
||||
"""
|
||||
Checks whether the provided args satisfy the properties described by
|
||||
the design in this metadata.
|
||||
|
||||
:param args: The arguments to check support for.
|
||||
:type args: RuntimeArguments
|
||||
|
||||
:return: Whether the provided args satisfy the properties described by
|
||||
the design in this metadata.
|
||||
:rtype: Status
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class BLASDesignMetadata(DesignMetadata):
|
||||
"""
|
||||
Design metadata for a basic-linear algebra subprogram (BLAS) operation.
|
||||
|
||||
These include fields for tiling-related parameters (e.g., tile shape and cluster shape).
|
||||
"""
|
||||
|
||||
tile_shape: tuple[int, ...]
|
||||
cluster_shape: tuple[int, ...]
|
||||
|
||||
def supports(self, args: RuntimeArguments) -> Status:
|
||||
"""
|
||||
Checks whether the provided args satisfy the properties described by
|
||||
the design in this metadata.
|
||||
|
||||
:param args: The arguments to check support for.
|
||||
:type args: RuntimeArguments
|
||||
|
||||
:return: Whether the provided args satisfy the properties described by
|
||||
the design in this metadata.
|
||||
:rtype: Status
|
||||
"""
|
||||
if args.performance is not None:
|
||||
return Status.fail(
|
||||
"BLASDesignMetadata does not yet support performance controls"
|
||||
)
|
||||
return Status.success()
|
||||
|
||||
|
||||
@dataclass
|
||||
class Sm100DesignMetadata(BLASDesignMetadata):
|
||||
"""
|
||||
Design metadata for kernels in the SM100 architecture family.
|
||||
"""
|
||||
|
||||
# Whether to use a 2CTA MMA instruction
|
||||
use_2cta_mma: bool
|
||||
|
||||
# Whether to use TMA to store the results of the operation
|
||||
use_tma_store: bool
|
||||
|
||||
@dataclass
|
||||
class GemmOperandsMetadata(OperandsMetadata):
|
||||
"""
|
||||
Metadata for the operands of a GEMM operation.
|
||||
|
||||
:param A: Metadata for the input tensor A.
|
||||
:type A: TensorAttributes
|
||||
:param B: Metadata for the input tensor B.
|
||||
:type B: TensorAttributes
|
||||
:param out: Metadata for the output tensor.
|
||||
:type out: TensorAttributes
|
||||
:param accumulator_type: The data type of the accumulator tensor.
|
||||
:type accumulator_type: cutlass.Numeric
|
||||
"""
|
||||
|
||||
A: TensorAttributes
|
||||
B: TensorAttributes
|
||||
out: TensorAttributes
|
||||
accumulator_type: cutlass.Numeric
|
||||
|
||||
def supports(self, other: GemmArguments | Self) -> Status:
|
||||
"""
|
||||
Checks whether the provided args satisfy the properties described by
|
||||
the operands in this metadata.
|
||||
|
||||
:param other: The arguments to check support for.
|
||||
:type other: GemmArguments | Self
|
||||
|
||||
:return: Whether the provided args satisfy the properties described by
|
||||
the operands in this metadata.
|
||||
:rtype: Status
|
||||
"""
|
||||
if isinstance(other, RuntimeArguments) and not isinstance(other, GemmArguments):
|
||||
return Status.fail(f"Expected GemmArguments, got {type(other)}")
|
||||
|
||||
if not (status := self.A.supports(other.A)):
|
||||
return Status.fail(f"Operand `A` is unsupported: {status.error}")
|
||||
|
||||
if not (status := self.B.supports(other.B)):
|
||||
return Status.fail(f"Operand `B` is unsupported: {status.error}")
|
||||
|
||||
if not (status := self.out.supports(other.out)):
|
||||
return Status.fail(f"Operand `out` is unsupported: {status.error}")
|
||||
|
||||
if self.accumulator_type != other.accumulator_type:
|
||||
return Status.fail(
|
||||
f"Expected accumulator type {self.accumulator_type}, got {other.accumulator_type}"
|
||||
)
|
||||
|
||||
return Status.success()
|
||||
|
||||
@staticmethod
|
||||
def from_args(args: GemmArguments) -> Self:
|
||||
"""
|
||||
Create a GemmOperandsMetadata from a GemmArguments.
|
||||
|
||||
:param args: The GemmArguments to create a GemmOperandsMetadata from.
|
||||
:type args: GemmArguments
|
||||
|
||||
:return: The GemmOperandsMetadata corresponding to the provided GemmArguments.
|
||||
:rtype: GemmOperandsMetadata
|
||||
"""
|
||||
return GemmOperandsMetadata(
|
||||
A=TensorAttributes.from_tensor(args.A),
|
||||
B=TensorAttributes.from_tensor(args.B),
|
||||
out=TensorAttributes.from_tensor(args.out),
|
||||
accumulator_type=args.accumulator_type,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ElementwiseOperandsMetadata(OperandsMetadata):
|
||||
"""
|
||||
Metadata for the operands of an elementwise operation.
|
||||
|
||||
:param A: Metadata for the input tensor A.
|
||||
:type A: TensorAttributes
|
||||
:param B: Metadata for the input tensor B.
|
||||
:type B: TensorAttributes
|
||||
:param out: Metadata for the output tensor.
|
||||
:type out: TensorAttributes
|
||||
"""
|
||||
|
||||
A: TensorAttributes
|
||||
B: TensorAttributes
|
||||
out: TensorAttributes
|
||||
|
||||
def supports(self, other: ElementwiseArguments | Self) -> Status:
|
||||
"""
|
||||
Checks whether the provided args satisfy the properties described by
|
||||
the operands in this metadata.
|
||||
|
||||
:param other: The arguments to check support for.
|
||||
:type other: ElementwiseArguments | Self
|
||||
|
||||
:return: Whether the provided args satisfy the properties described by
|
||||
the operands in this metadata.
|
||||
:rtype: Status
|
||||
"""
|
||||
if isinstance(other, RuntimeArguments) and not isinstance(
|
||||
other, ElementwiseArguments
|
||||
):
|
||||
return Status.fail(f"Expected ElementwiseArguments, got {type(other)}")
|
||||
|
||||
if not (status := self.A.supports(other.A)):
|
||||
return Status.fail(f"Operand `A` is unsupported: {status.error}")
|
||||
|
||||
if not (status := self.B.supports(other.B)):
|
||||
return Status.fail(f"Operand `B` is unsupported: {status.error}")
|
||||
|
||||
if not (status := self.out.supports(other.out)):
|
||||
return Status.fail(f"Operand `out` is unsupported: {status.error}")
|
||||
|
||||
return Status.success()
|
||||
|
||||
@staticmethod
|
||||
def from_args(args: ElementwiseArguments) -> Self:
|
||||
"""
|
||||
Create a ElementwiseOperandsMetadata from a ElementwiseArguments.
|
||||
|
||||
:param args: The ElementwiseArguments to create a ElementwiseOperandsMetadata from.
|
||||
:type args: ElementwiseArguments
|
||||
|
||||
:return: The ElementwiseOperandsMetadata corresponding to the provided ElementwiseArguments.
|
||||
:rtype: ElementwiseOperandsMetadata
|
||||
"""
|
||||
return ElementwiseOperandsMetadata(
|
||||
A=TensorAttributes.from_tensor(args.A),
|
||||
B=TensorAttributes.from_tensor(args.B),
|
||||
out=TensorAttributes.from_tensor(args.out),
|
||||
)
|
||||
|
||||
|
||||
class EpilogueMetadata:
|
||||
def __init__(self, epilogue_args: EpilogueArguments):
|
||||
self.traced_epilogue = epilogue_args.traced_epilogue
|
||||
self.tensors = epilogue_args.tensors
|
||||
self.epilogue_fn = epilogue_args.epilogue_fn
|
||||
|
||||
@staticmethod
|
||||
def from_args(args: EpilogueArguments) -> Self:
|
||||
# For now, EpilogueArguments and EpilogueMetadata are the same
|
||||
return EpilogueMetadata(args)
|
||||
|
||||
@property
|
||||
def parameters(self) -> list[cute.Tensor | cutlass.Numeric]:
|
||||
return list(self.tensors.values())
|
||||
|
||||
@property
|
||||
def parameter_names(self) -> list[str]:
|
||||
return list(self.tensors.keys())
|
||||
|
||||
def supports(self, args: RuntimeArguments) -> Status:
|
||||
return Status.success()
|
||||
|
||||
|
||||
@dataclass
|
||||
class KernelMetadata:
|
||||
"""
|
||||
Metadata describing the operands and design of a kernel.
|
||||
|
||||
In addition to information about operands, this metadata also contains
|
||||
the properties of the kernel that are independent of the specific arguments
|
||||
that are passed to the kernel (e.g., tile shape, cluster shape, etc.).
|
||||
:param kernel_name: The name of the kernel.
|
||||
:type kernel_name: str
|
||||
:param kernel_class: The class of the kernel.
|
||||
:type kernel_class: type["Kernel"]
|
||||
:param min_cc: The minimum compute capability of the kernel.
|
||||
:type min_cc: int
|
||||
:param operands: Metadata for the operands of the kernel.
|
||||
:type operands: OperandsT
|
||||
:param design: Metadata for the design of the kernel.
|
||||
:type design: DesignT | None
|
||||
:param epilogue: Metadata for the epilogue of the kernel.
|
||||
:type epilogue: EpilogueT | None
|
||||
"""
|
||||
|
||||
kernel_name: str
|
||||
kernel_class: type["Kernel"]
|
||||
min_cc: int
|
||||
operands: OperandsMetadata
|
||||
design: DesignMetadata | None = None
|
||||
epilogue: EpilogueMetadata | None = None
|
||||
|
||||
def supports(self, args: RuntimeArguments) -> Status:
|
||||
"""
|
||||
Checks whether the provided args satisfy the properties described by
|
||||
the operands, design, and epilogue metadata.
|
||||
|
||||
:param args: The arguments to check support for.
|
||||
:type args: RuntimeArguments
|
||||
|
||||
:return: Whether the provided args satisfy the properties described by
|
||||
the operands, design, and epilogue metadata.
|
||||
:rtype: Status
|
||||
"""
|
||||
|
||||
def supports_or_none(member, corresponding_arg, name: str) -> Status:
|
||||
# If metadata is absent, accept only when the corresponding argument is also absent.
|
||||
if member is None:
|
||||
if corresponding_arg is None:
|
||||
return Status.success()
|
||||
return Status.fail(
|
||||
f"{name} metadata is absent but argument is provided"
|
||||
)
|
||||
return member.supports(args)
|
||||
|
||||
if not (status := self.operands.supports(args)):
|
||||
return status
|
||||
|
||||
if not (status := supports_or_none(self.design, args.performance, "design")):
|
||||
return status
|
||||
|
||||
if not (status := supports_or_none(self.epilogue, args.epilogue, "epilogue")):
|
||||
return status
|
||||
|
||||
return Status.success()
|
||||
62
python/cutlass_api/cutlass_api/providers/__init__.py
Normal file
62
python/cutlass_api/cutlass_api/providers/__init__.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from cutlass_api.providers.provider import ProviderBase
|
||||
|
||||
available_providers: dict[str, type[ProviderBase]] = {}
|
||||
|
||||
|
||||
def register_provider(name: str) -> Callable[[type[ProviderBase]], type[ProviderBase]]:
|
||||
"""
|
||||
Decorator used to register a provider class with the given name.
|
||||
|
||||
:param name: the name of the provider
|
||||
:type name: str
|
||||
|
||||
:return: the wrapper function
|
||||
:rtype: Callable[[Type[ProviderBase]], Type[ProviderBase]]
|
||||
"""
|
||||
|
||||
def wrapper(provider_class: type[ProviderBase]) -> type[ProviderBase]:
|
||||
"""
|
||||
Wrapper function to register a provider class with the given name.
|
||||
|
||||
:param provider_class: the provider class to register
|
||||
:type provider_class: Type[ProviderBase]
|
||||
"""
|
||||
global available_providers
|
||||
available_providers[name] = provider_class
|
||||
return provider_class
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# Import for side effects (provider registration)
|
||||
import cutlass_api.providers.cutedsl # noqa: F401, E402
|
||||
88
python/cutlass_api/cutlass_api/providers/cutedsl/__init__.py
Normal file
88
python/cutlass_api/cutlass_api/providers/cutedsl/__init__.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
from typing import Type
|
||||
|
||||
from cutlass_api.arguments import EpilogueArguments
|
||||
from cutlass_api.kernel import Kernel
|
||||
from cutlass_api.metadata import KernelMetadata
|
||||
from cutlass_api.providers import register_provider
|
||||
|
||||
|
||||
def try_import() -> bool:
|
||||
import importlib.util
|
||||
|
||||
if importlib.util.find_spec("cutlass") is None:
|
||||
logging.warning(
|
||||
"'cutlass' could not be imported. The cutedsl provider will not be available."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
available = try_import()
|
||||
|
||||
|
||||
if available:
|
||||
|
||||
@register_provider("cutedsl")
|
||||
class CuTeDSLProvider:
|
||||
# Kernel classes currently registered with this provider
|
||||
_kernel_classes = []
|
||||
|
||||
@classmethod
|
||||
def generate_kernels(
|
||||
cls,
|
||||
metadata_filter: Callable[[KernelMetadata], bool] | None,
|
||||
epilogue_args: EpilogueArguments = None,
|
||||
cc: int = None,
|
||||
) -> list[Kernel]:
|
||||
kernels_for_provider = []
|
||||
for kernel_cls in cls._kernel_classes:
|
||||
kernels_for_provider.extend(
|
||||
kernel_cls.generate_kernels(
|
||||
metadata_filter, epilogue_args=epilogue_args, cc=cc
|
||||
)
|
||||
)
|
||||
|
||||
return kernels_for_provider
|
||||
|
||||
@classmethod
|
||||
def register(cls, kernel_class: type[Kernel]) -> type[Kernel]:
|
||||
cls._kernel_classes.append(kernel_class)
|
||||
return kernel_class
|
||||
|
||||
# Imports for side effects (kernel registration)
|
||||
import cutlass_api.providers.cutedsl.elementwise # noqa: F401
|
||||
import cutlass_api.providers.cutedsl.gemm # noqa: F401
|
||||
|
||||
__all__ = ["CuTeDSLProvider"]
|
||||
else:
|
||||
__all__ = []
|
||||
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# Importing is done for the side effect of registering the kernel classes with the provider.
|
||||
# Suppress linter warnings about unused import.
|
||||
# ruff: noqa: F401
|
||||
import cutlass_api.providers.cutedsl.elementwise.elementwise_add
|
||||
@@ -0,0 +1,240 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from collections.abc import Callable
|
||||
from itertools import product
|
||||
from typing import ClassVar
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
from cutlass_api.arguments import ElementwiseArguments, EpilogueArguments
|
||||
from cutlass_api.artifact import CompiledArtifact
|
||||
from cutlass_api.metadata import (
|
||||
ElementwiseOperandsMetadata,
|
||||
KernelMetadata,
|
||||
TensorAttributes,
|
||||
)
|
||||
from cutlass_api.providers.cutedsl import CuTeDSLProvider
|
||||
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
|
||||
from cutlass_api.utils import to_cuda_stream
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
|
||||
"""
|
||||
An Elementwise Addition kernel using CuTe DSL:
|
||||
out = A + B
|
||||
"""
|
||||
|
||||
|
||||
@CuTeDSLProvider.register
|
||||
class ElementwiseAddKernel(CuteDslKernel):
|
||||
_supported_dtypes: ClassVar[list[type[cutlass.Numeric]]] = [
|
||||
cutlass.Float32,
|
||||
cutlass.Float16,
|
||||
]
|
||||
|
||||
def __init__(self, metadata: KernelMetadata):
|
||||
self.metadata = metadata
|
||||
self.impl = ElementwiseAddKernelImpl()
|
||||
|
||||
def compile(self, args: ElementwiseArguments, cc: int = None) -> CompiledArtifact:
|
||||
stream = cutlass.cute.runtime.make_fake_stream()
|
||||
compiled_kernel = self.cute_compile(self.impl, args.A, args.B, args.out, stream)
|
||||
return CompiledArtifact(compiled_kernel, self)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
args: ElementwiseArguments,
|
||||
compiled_artifact: CompiledArtifact,
|
||||
stream,
|
||||
workspace=None,
|
||||
) -> None:
|
||||
stream = to_cuda_stream(stream)
|
||||
compiled_kernel = compiled_artifact.compiled_obj
|
||||
self.cute_run(compiled_kernel, args.A, args.B, args.out, stream)
|
||||
|
||||
@staticmethod
|
||||
def generate_kernels(
|
||||
metadata_filter: Callable[[KernelMetadata], bool],
|
||||
epilogue_args: EpilogueArguments = None,
|
||||
cc: int = None,
|
||||
) -> list["ElementwiseAddKernel"]:
|
||||
if epilogue_args is not None:
|
||||
return []
|
||||
|
||||
min_cc = 80
|
||||
if cc is not None and cc < min_cc:
|
||||
return []
|
||||
|
||||
stride_names = {
|
||||
(0, 1): "t", # row major
|
||||
(1, 0): "n", # column major
|
||||
}
|
||||
|
||||
kernel_list = []
|
||||
for dtype in ElementwiseAddKernel._supported_dtypes:
|
||||
alignment = 128 // dtype.width
|
||||
for stride_A, stride_B, stride_out in product(
|
||||
stride_names.keys(), repeat=3
|
||||
):
|
||||
kernel_name = (
|
||||
"cutedsl.ElementwiseAddKernel"
|
||||
+ f"_A{dtype}_{stride_names[stride_A]}"
|
||||
+ f"_B{dtype}_{stride_names[stride_B]}"
|
||||
+ f"_out{dtype}_{stride_names[stride_out]}"
|
||||
)
|
||||
|
||||
operands = ElementwiseOperandsMetadata(
|
||||
A=TensorAttributes(
|
||||
dtype=dtype, stride=stride_A, alignment=alignment
|
||||
),
|
||||
B=TensorAttributes(
|
||||
dtype=dtype, stride=stride_B, alignment=alignment
|
||||
),
|
||||
out=TensorAttributes(
|
||||
dtype=dtype, stride=stride_out, alignment=alignment
|
||||
),
|
||||
)
|
||||
metadata = KernelMetadata(
|
||||
operands=operands, kernel_name=kernel_name, kernel_class=ElementwiseAddKernel, min_cc=min_cc
|
||||
)
|
||||
if metadata_filter(metadata):
|
||||
kernel_list.append(ElementwiseAddKernel(metadata))
|
||||
|
||||
return kernel_list
|
||||
|
||||
|
||||
class ElementwiseAddKernelImpl:
|
||||
@cute.jit
|
||||
def __call__(
|
||||
self, mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor, stream: cuda.CUstream
|
||||
):
|
||||
copy_bits: cutlass.Constexpr = 128
|
||||
dtype = mA.element_type
|
||||
vector_size = copy_bits // dtype.width
|
||||
|
||||
thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0))
|
||||
val_layout = cute.make_ordered_layout((4, vector_size), order=(1, 0))
|
||||
tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
|
||||
|
||||
gA = cute.zipped_divide(mA, tiler_mn) # ((TileM,TileN),(RestM,RestN))
|
||||
gB = cute.zipped_divide(mB, tiler_mn) # ((TileM,TileN),(RestM,RestN))
|
||||
gC = cute.zipped_divide(mC, tiler_mn) # ((TileM,TileN),(RestM,RestN))
|
||||
|
||||
idC = cute.make_identity_tensor(mC.shape)
|
||||
cC = cute.zipped_divide(idC, tiler=tiler_mn)
|
||||
|
||||
self.kernel(gA, gB, gC, cC, mC.shape, thr_layout, val_layout).launch(
|
||||
grid=[cute.size(gC, mode=[1]), 1, 1],
|
||||
block=[cute.size(tv_layout, mode=[0]), 1, 1],
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@cute.kernel
|
||||
def kernel(
|
||||
self,
|
||||
gA: cute.Tensor,
|
||||
gB: cute.Tensor,
|
||||
gC: cute.Tensor,
|
||||
cC: cute.Tensor, # coordinate tensor
|
||||
shape: cute.Shape,
|
||||
thr_layout: cute.Layout,
|
||||
val_layout: cute.Layout,
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, _, _ = cute.arch.block_idx()
|
||||
|
||||
# slice for CTAs
|
||||
# logical id -> address
|
||||
blk_coord = ((None, None), bidx)
|
||||
blkA = gA[blk_coord] # (TileM,TileN)
|
||||
blkB = gB[blk_coord] # (TileM,TileN)
|
||||
blkC = gC[blk_coord] # (TileM,TileN)
|
||||
blkCrd = cC[blk_coord] # (TileM, TileN)
|
||||
|
||||
# declare the atoms which will be used later for memory copy
|
||||
copy_atom_load = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(), gA.element_type
|
||||
)
|
||||
copy_atom_store = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(), gC.element_type
|
||||
)
|
||||
|
||||
tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
|
||||
tiled_copy_B = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout)
|
||||
tiled_copy_C = cute.make_tiled_copy_tv(copy_atom_store, thr_layout, val_layout)
|
||||
|
||||
thr_copy_A = tiled_copy_A.get_slice(tidx)
|
||||
thr_copy_B = tiled_copy_B.get_slice(tidx)
|
||||
thr_copy_C = tiled_copy_C.get_slice(tidx)
|
||||
|
||||
thrA = thr_copy_A.partition_S(blkA)
|
||||
thrB = thr_copy_B.partition_S(blkB)
|
||||
thrC = thr_copy_C.partition_S(blkC)
|
||||
|
||||
# allocate fragments for gmem->rmem
|
||||
frgA = cute.make_fragment_like(thrA)
|
||||
frgB = cute.make_fragment_like(thrB)
|
||||
frgC = cute.make_fragment_like(thrC)
|
||||
|
||||
thrCrd = thr_copy_C.partition_S(blkCrd)
|
||||
frgPred = cute.make_rmem_tensor(thrCrd.shape, cutlass.Boolean)
|
||||
|
||||
for i in range(0, cute.size(frgPred), 1):
|
||||
val = cute.elem_less(thrCrd[i], shape)
|
||||
frgPred[i] = val
|
||||
|
||||
# Print per thread predicate mask
|
||||
# if tidx == 0 and bidx == 0:
|
||||
# cute.printf("block_dim = {}", cute.arch.grid_dim())
|
||||
# cute.printf("shape = {}", shape)
|
||||
# cute.print_tensor(thrA)
|
||||
# cute.print_tensor(thrB)
|
||||
# cute.print_tensor(frgPred)
|
||||
|
||||
##########################################################
|
||||
# Move data to reg address space
|
||||
##########################################################
|
||||
|
||||
cute.copy(copy_atom_load, thrA, frgA, pred=frgPred)
|
||||
cute.copy(copy_atom_load, thrB, frgB, pred=frgPred)
|
||||
|
||||
# if tidx == 0 and bidx == 0:
|
||||
# cute.print_tensor(frgA)
|
||||
# cute.print_tensor(frgB)
|
||||
|
||||
# Load data before use. The compiler will optimize the copy and load
|
||||
# operations to convert some memory ld/st into register uses.
|
||||
result = frgA.load() + frgB.load()
|
||||
|
||||
# Save the results back to registers. Here we reuse b's registers.
|
||||
frgC.store(result)
|
||||
|
||||
# Copy the results back to c
|
||||
cute.copy(copy_atom_store, frgC, thrC, pred=frgPred)
|
||||
1056
python/cutlass_api/cutlass_api/providers/cutedsl/evt/common_efc.py
Normal file
1056
python/cutlass_api/cutlass_api/providers/cutedsl/evt/common_efc.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,333 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from cutlass_api.fusion.ir import DAGIR, LoadNode, StoreNode, TopoVisitorNode
|
||||
from cutlass_api.fusion.ir.load_nodes import (
|
||||
ColumnBroadcastImpl,
|
||||
RowBroadcastImpl,
|
||||
ScalarBroadcastImpl,
|
||||
)
|
||||
from cutlass_api.fusion.ir.store_nodes import ReductionImplBase
|
||||
from cutlass_api.fusion.library import ActivationOp, FunctionalOp
|
||||
from cutlass_api.providers.cutedsl.evt import common_efc
|
||||
from cutlass_api.status import Status
|
||||
|
||||
OpToCuteImplStr = {
|
||||
FunctionalOp.Exp: lambda x: f"exp({x})",
|
||||
ActivationOp.ReLU: lambda x: f"relu({x})",
|
||||
ActivationOp.Sigmoid: lambda x: f"sigmoid({x})",
|
||||
ActivationOp.Tanh: lambda x: f"tanh({x})",
|
||||
FunctionalOp.Divides: lambda x, y: f"({x} / {y})",
|
||||
FunctionalOp.Multiplies: lambda x, y: f"({x} * {y})",
|
||||
FunctionalOp.Plus: lambda x, y: f"({x} + {y})",
|
||||
FunctionalOp.Minus: lambda x, y: f"({x} - {y})",
|
||||
}
|
||||
|
||||
|
||||
# Functions to eventually be run as part of the EFC function.
|
||||
# Every function takes in `efc_config` as the first argument (even if it
|
||||
# is not used). This is necssary for running analysis
|
||||
# passes on the EFC function absent an MLIR context (which would be
|
||||
# needed is we used `cute.exp` directly).
|
||||
|
||||
OpToCuteImpl = {
|
||||
FunctionalOp.Exp: lambda efc_config, x: efc_config.exp(x),
|
||||
ActivationOp.ReLU: lambda efc_config, x: efc_config.where(
|
||||
x > 0, x, efc_config.full_like(x, 0)
|
||||
),
|
||||
ActivationOp.Sigmoid: lambda efc_config, x: 1.0 / (1.0 + efc_config.exp(-x)),
|
||||
ActivationOp.Tanh: lambda efc_config, x: efc_config.tanh(x),
|
||||
FunctionalOp.Divides: lambda efc_config, x, y: x / y,
|
||||
FunctionalOp.Multiplies: lambda efc_config, x, y: x * y,
|
||||
FunctionalOp.Plus: lambda efc_config, x, y: x + y,
|
||||
FunctionalOp.Minus: lambda efc_config, x, y: x - y,
|
||||
}
|
||||
|
||||
|
||||
def store(efc_config, x, y):
|
||||
x.store(y)
|
||||
|
||||
|
||||
def load(efc_config, x):
|
||||
return x.load()
|
||||
|
||||
|
||||
def get_val(x):
|
||||
return lambda efc_config: x
|
||||
|
||||
|
||||
class EFCConverter:
|
||||
"""
|
||||
Helper class to translate from DAGIR to the CuTe DSL epilogue visitor tree (EVT) structure
|
||||
|
||||
The CuTe DSL EVT structure is as follows for a (alpha * accum + beta * C) epilogue:
|
||||
def epi(self, C, alpha, beta, D):
|
||||
C_val = C.load()
|
||||
alpha_val = alpha.load()
|
||||
beta_val = beta.load()
|
||||
compute_0_val = alpha_val * self.accum()
|
||||
compute_1_val = beta_val * C_val
|
||||
compute_2_val = compute_0_val + compute_1_val
|
||||
D.store(compute_2_val)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def convert(dag_ir: DAGIR, parameter_names: list[str]) -> Callable:
|
||||
"""
|
||||
Converts the DAGIR to a callable epilogue function supported by CuTe DSL EFC.
|
||||
|
||||
The simplest way to do this would be to convert the DAGIR into the equivalent
|
||||
string representation of the EFC function and call `exec` on it to get the callable.
|
||||
However, this is generally not considered safe as it can potentially open avenues
|
||||
for allowing arbitrary code execution.
|
||||
|
||||
Instead, we define a generic configurable epilogue that we specialize based
|
||||
on the DAGIR itself. The generic epilogue takes in a list of parameters
|
||||
and executes a predefined sequence of operations. Each operation is executed
|
||||
in order and places its result on a stack.
|
||||
|
||||
Outside this generic epilogue, we define the sequence of operations that
|
||||
are to be executed and determine where the sources for such operations live
|
||||
(either their index in the parameter list or on the stack).
|
||||
|
||||
Example
|
||||
=======
|
||||
Suppose we have the following epilogue:
|
||||
```
|
||||
def epi(accum, alpha):
|
||||
D = accum * alpha
|
||||
return D
|
||||
```
|
||||
A simple string representation of the epilogue in EFC format would be:
|
||||
|
||||
```
|
||||
def efc_epi(efc_config, alpha, D):
|
||||
accum_val = efc_config.accum()
|
||||
alpha_val = alpha.load()
|
||||
temp = accum_val * alpha_val
|
||||
D.store(temp)
|
||||
```
|
||||
|
||||
We would like to generalize this to a function with signature:
|
||||
```
|
||||
def efc_epi(efc_config, *parameters):
|
||||
```
|
||||
that can perform arbitrary operations.
|
||||
|
||||
We would know from `parameter_names` that the list of parameters provided to the EFC epilogue
|
||||
will be [alpha, D]. We then traverse the DAGIR and see the following operations in order:
|
||||
|
||||
LoadNode(accum)
|
||||
LoadNode(alpha)
|
||||
ComputeNode(mul, alpha, accum) -> temp
|
||||
StoreNode(temp, D)
|
||||
|
||||
We can convert this into a series of corresponding operations that operate on variables
|
||||
corresponding to indices. Indices represent either a position in the parameter list
|
||||
or a position on the stack (index < len(parameters) implying a position in the parameter list).
|
||||
In this example, [alpha, D] correspond to indices 0 and 1, respectively.
|
||||
|
||||
Traversing each DAGIR node above, we get:
|
||||
LoadNode(accum) -> efc_config.accum(), result is in stack[0] (index 2)
|
||||
LoadNode(alpha) -> alpha.load()
|
||||
-> parameter[0].load(), result is in stack[1] (index 3)
|
||||
ComputeNode(mul, alpha, accum) -> stack[1] * stack[0] (index 3 * index 2), result is in stack[2] (index 4)
|
||||
StoreNode(temp, D) -> D.store(stack[2])
|
||||
-> paramter[1].store(stack[2])
|
||||
|
||||
This is encoded as the following tuples:
|
||||
```
|
||||
ops = [
|
||||
# Load of accum omitted because it is performed automatically in the generic epilogue
|
||||
(load, 0),
|
||||
(mul, 3, 2),
|
||||
(store, 1, 2),
|
||||
]
|
||||
```
|
||||
|
||||
The generic epilogue then simply performs:
|
||||
```
|
||||
def efc_epi(efc_config, *parameters):
|
||||
stack = [efc_config.accum()]
|
||||
for op in ops:
|
||||
fn = op[0]
|
||||
inputs = [get(idx) for idx in op[1:]]
|
||||
stack.append(fn(efc_config, *inputs))
|
||||
```
|
||||
Where the `get()` function is a helper that returns either a value from the parameter list or the stack,
|
||||
depending on the index.
|
||||
"""
|
||||
# Provide a unique identifier for cases in which `accum` is also being written out
|
||||
# We use an integer since parameter_names is a list of strings -- -1 is guaranteed
|
||||
# not to be in the parameter_names list, but can still be used as a key in the dictionary.
|
||||
accum_out_name = -1
|
||||
|
||||
# If 'accum' is in parameter_names, we know that it must be because
|
||||
# the accumulator is also being written out
|
||||
accum_out_loc = -1
|
||||
if "accum" in parameter_names:
|
||||
# Rename the output version temporarily so as not to confuse with the
|
||||
# input accumulator
|
||||
accum_out_loc = parameter_names.index("accum")
|
||||
parameter_names[accum_out_loc] = accum_out_name
|
||||
|
||||
name_to_idx = {}
|
||||
for idx, name in enumerate(parameter_names):
|
||||
name_to_idx[name] = idx
|
||||
|
||||
cur_idx = len(name_to_idx)
|
||||
|
||||
def add_name(name: str):
|
||||
nonlocal cur_idx
|
||||
name_to_idx[name] = cur_idx
|
||||
cur_idx += 1
|
||||
|
||||
def idx(name: str):
|
||||
val = name_to_idx[name]
|
||||
if isinstance(val, str):
|
||||
return idx(val)
|
||||
return val
|
||||
|
||||
add_name("accum")
|
||||
|
||||
# Each entry is a tuple containing a load/store/compute op and operands needed for it.
|
||||
ops = []
|
||||
debug_string_ops = []
|
||||
|
||||
for meta in dag_ir.node_metas_topological_order():
|
||||
if isinstance(meta, LoadNode) and getattr(meta, "is_output", False):
|
||||
assert meta.name == "accum"
|
||||
# Handle the special case where the accumulator is also being written out
|
||||
# This occurs in DAG IR as a load node with is_output = True
|
||||
ops.append((store, idx(accum_out_name), idx("accum")))
|
||||
debug_string_ops.append("accum_out_name.store(accum)")
|
||||
cur_idx += 1
|
||||
if isinstance(meta, LoadNode):
|
||||
# Add new values to the stack for any operations that need to be .load'ed.
|
||||
# This includes any non-scalar input parameters
|
||||
is_scalar = isinstance(meta.underlying_impl, ScalarBroadcastImpl)
|
||||
is_param_scalar = is_scalar and meta.name in parameter_names
|
||||
|
||||
if meta.name != "accum" and not is_param_scalar:
|
||||
if is_scalar:
|
||||
ops.append((get_val(meta.tensor.value),))
|
||||
add_name(meta.name)
|
||||
else:
|
||||
ops.append((load, idx(meta.name)))
|
||||
debug_string_ops.append(f"{meta.name} = {meta.name}.load()")
|
||||
|
||||
# Update the entry in name_to_idx to the index of the loaded value
|
||||
add_name(meta.name)
|
||||
elif isinstance(meta, StoreNode):
|
||||
children = dag_ir.get_all_inputs(meta.name)
|
||||
if len(children) != 1:
|
||||
raise ValueError(
|
||||
f"Store node {meta.name} has {len(children)} children, but only one is supported"
|
||||
)
|
||||
child = children[0]
|
||||
ops.append((store, idx(meta.name), idx(child)))
|
||||
debug_string_ops.append(f"{meta.name}.store({child})")
|
||||
|
||||
# We want to map the following sequence:
|
||||
# def epi(accum, x, y):
|
||||
# out0 = accum * x
|
||||
# out1 = out0 * y
|
||||
# return out0, out1
|
||||
# To:
|
||||
# x_val = x.load()
|
||||
# y_val = y.load()
|
||||
# c0 = self.Acc() * x_val
|
||||
# out0.store(c0)
|
||||
# c1 = c0 * y_val
|
||||
# out1.store(c1)
|
||||
# To support computation on top of out0, we need to remap it to c0 so that the next computation can use it.
|
||||
name_to_idx[meta.name] = child
|
||||
cur_idx += 1
|
||||
elif isinstance(meta, TopoVisitorNode):
|
||||
raise ValueError(f"TopoVisitorNode {meta.name} is not supported")
|
||||
else:
|
||||
if dag_ir.in_degree(meta.name) == 0:
|
||||
continue
|
||||
|
||||
sorted_children = [
|
||||
idx(child_name) for child_name in dag_ir.get_all_inputs(meta.name)
|
||||
]
|
||||
entry = (OpToCuteImpl[meta.fn], *sorted_children)
|
||||
ops.append(entry)
|
||||
debug_string_ops.append(
|
||||
f"{meta.name} = {OpToCuteImplStr[meta.fn](*dag_ir.get_all_inputs(meta.name))}"
|
||||
)
|
||||
add_name(meta.name)
|
||||
|
||||
def epi(efc_config, *parameters):
|
||||
stack = [efc_config.accum()]
|
||||
|
||||
def get(idx: int):
|
||||
if idx < len(parameters):
|
||||
return parameters[idx]
|
||||
else:
|
||||
return stack[idx - len(parameters)]
|
||||
|
||||
for i, op in enumerate(ops):
|
||||
fn = op[0]
|
||||
inputs = [get(idx) for idx in op[1:]]
|
||||
stack.append(fn(efc_config, *inputs))
|
||||
|
||||
if accum_out_loc != -1:
|
||||
# Restore the original name of the accumulator
|
||||
parameter_names[accum_out_loc] = "accum"
|
||||
|
||||
named_epi = common_efc.create_named_epilogue(["self", *parameter_names], epi)
|
||||
return named_epi
|
||||
|
||||
@staticmethod
|
||||
def supports(dag_ir: DAGIR) -> Status:
|
||||
"""
|
||||
Checks if the DAGIR is supported by CuTe DSL EFC.
|
||||
"""
|
||||
# Currently do not support TopVisitorNode, row/col broadcasts, and reductions
|
||||
for meta in dag_ir.node_metas_topological_order():
|
||||
if isinstance(meta, TopoVisitorNode):
|
||||
return Status.fail("TopoVisitorNode is not supported")
|
||||
if isinstance(meta.underlying_impl, RowBroadcastImpl):
|
||||
return Status.fail("RowBroadcastImpl is not supported")
|
||||
if isinstance(meta.underlying_impl, ColumnBroadcastImpl):
|
||||
return Status.fail("ColumnBroadcastImpl is not supported")
|
||||
if isinstance(meta.underlying_impl, ReductionImplBase):
|
||||
return Status.fail("ReductionImplBase is not supported")
|
||||
return Status.success()
|
||||
|
||||
@staticmethod
|
||||
def identity_efc(self, D):
|
||||
D.store(self.accum())
|
||||
@@ -0,0 +1,33 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# Importing is done for the side effect of registering the kernel classes with the provider.
|
||||
# Suppress linter warnings about unused import.
|
||||
# ruff: noqa: F401
|
||||
import cutlass_api.providers.cutedsl.gemm.sm100_static_persistent
|
||||
import cutlass_api.providers.cutedsl.gemm.sm100_static_persistent_efc
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,414 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import itertools
|
||||
from collections.abc import Callable, Generator
|
||||
|
||||
import cutlass
|
||||
|
||||
from cutlass_api.arguments import (
|
||||
EpilogueArguments,
|
||||
GemmArguments,
|
||||
)
|
||||
from cutlass_api.artifact import CompiledArtifact
|
||||
from cutlass_api.metadata import (
|
||||
GemmOperandsMetadata,
|
||||
KernelMetadata,
|
||||
Sm100DesignMetadata,
|
||||
TensorAttributes,
|
||||
)
|
||||
from cutlass_api.providers.cutedsl import CuTeDSLProvider
|
||||
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
|
||||
from cutlass_api.providers.cutedsl.utils import get_max_active_clusters
|
||||
from cutlass_api.status import Status
|
||||
from cutlass_api.utils import strides_to_layout_string, to_cuda_stream, tuple_to_string
|
||||
|
||||
from .implementations.sm100_static_persistent_impl import PersistentDenseGemmKernelImpl
|
||||
|
||||
|
||||
@CuTeDSLProvider.register
|
||||
class PersistentDenseGemmKernel(CuteDslKernel):
|
||||
"""This class implements batched matrix multiplication (C = A @ B) with support for various data types
|
||||
and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.
|
||||
|
||||
:note: In current version, A and B tensor must have the same data type
|
||||
- i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported
|
||||
|
||||
:note: Supported A/B data types:
|
||||
- TFloat32
|
||||
- Float16/BFloat16
|
||||
- Int8/Uint8
|
||||
- Float8E4M3FN/Float8E5M2
|
||||
|
||||
:note: Supported accumulator data types:
|
||||
- Float32 (for all floating point A/B data types)
|
||||
- Float16 (only for fp16 and fp8 A/B data types)
|
||||
- Int32 (only for uint8/int8 A/B data types)
|
||||
|
||||
:note: Supported C data types:
|
||||
- Float32 (for float32 and int32 accumulator data types)
|
||||
- Int32 (for float32 and int32 accumulator data types)
|
||||
- Float16/BFloat16 (for fp16 and fp8 accumulator data types)
|
||||
- Int8/Uint8 (for uint8/int8 accumulator data types)
|
||||
- Float8E4M3FN/Float8E5M2 (for float32 accumulator data types)
|
||||
|
||||
:note: Constraints:
|
||||
- MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True)
|
||||
- MMA tiler N must be 32-256, step 32
|
||||
- Cluster shape M must be multiple of 2 if use_2cta_instrs=True
|
||||
- Cluster shape M/N must be positive and power of 2, total cluster size <= 16
|
||||
"""
|
||||
|
||||
def __init__(self, metadata: KernelMetadata):
|
||||
self.metadata = metadata
|
||||
|
||||
def epilogue_op(x):
|
||||
return x
|
||||
|
||||
mma_tiler_mn = (metadata.design.tile_shape[0], metadata.design.tile_shape[1])
|
||||
cluster_shape_mn = (
|
||||
metadata.design.cluster_shape[0],
|
||||
metadata.design.cluster_shape[1],
|
||||
)
|
||||
|
||||
self.impl = PersistentDenseGemmKernelImpl(
|
||||
metadata.operands.accumulator_type,
|
||||
metadata.design.use_2cta_mma,
|
||||
mma_tiler_mn,
|
||||
cluster_shape_mn,
|
||||
metadata.design.use_tma_store,
|
||||
epilogue_op,
|
||||
)
|
||||
|
||||
def _supports(self, args: GemmArguments) -> Status:
|
||||
if args.epilogue is not None:
|
||||
return Status.fail("This kernel does not support any epilogue fusion.")
|
||||
return Status.success()
|
||||
|
||||
def compile(self, args: GemmArguments, cc: int = None) -> CompiledArtifact:
|
||||
stream = cutlass.cute.runtime.make_fake_stream()
|
||||
max_active_clusters = get_max_active_clusters(self.impl.cluster_shape_mn)
|
||||
|
||||
compiled_kernel = self.cute_compile(
|
||||
self.impl,
|
||||
args.A,
|
||||
args.B,
|
||||
args.out,
|
||||
max_active_clusters,
|
||||
stream,
|
||||
self.impl.epilogue_op,
|
||||
)
|
||||
return CompiledArtifact(compiled_kernel, self)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
args: GemmArguments,
|
||||
compiled_artifact: CompiledArtifact,
|
||||
stream,
|
||||
workspace=None,
|
||||
) -> None:
|
||||
stream = to_cuda_stream(stream)
|
||||
compiled_gemm = compiled_artifact.compiled_obj
|
||||
self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)
|
||||
|
||||
@staticmethod
|
||||
def _valid_operands(operands: GemmOperandsMetadata) -> bool:
|
||||
if not isinstance(operands, GemmOperandsMetadata):
|
||||
return False
|
||||
|
||||
# In current version, A and B tensor must have the same data type
|
||||
# i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported
|
||||
if operands.A.dtype != operands.B.dtype:
|
||||
return False
|
||||
|
||||
abtype = operands.A.dtype
|
||||
|
||||
# Supported A/B data types:
|
||||
# - TFloat32
|
||||
# - Float16/BFloat16
|
||||
# - Int8/Uint8
|
||||
# - Float8E4M3FN/Float8E5M2
|
||||
if abtype not in [
|
||||
cutlass.Float32,
|
||||
cutlass.Float16,
|
||||
cutlass.BFloat16,
|
||||
cutlass.Int8,
|
||||
cutlass.Uint8,
|
||||
cutlass.Float8E4M3FN,
|
||||
cutlass.Float8E5M2,
|
||||
]:
|
||||
return False
|
||||
|
||||
# Supported accumulator data types:
|
||||
# - Float32 (for all floating point A/B data types)
|
||||
# - Float16 (only for fp16 and fp8 A/B data types)
|
||||
# - Int32 (only for uint8/int8 A/B data types)
|
||||
if operands.accumulator_type == cutlass.Float32:
|
||||
if not abtype.is_float:
|
||||
return False
|
||||
elif operands.accumulator_type == cutlass.Float16:
|
||||
if abtype not in [
|
||||
cutlass.Float16,
|
||||
cutlass.Float8E4M3FN,
|
||||
cutlass.Float8E5M2,
|
||||
]:
|
||||
return False
|
||||
elif operands.accumulator_type == cutlass.Int32:
|
||||
if abtype not in [cutlass.Uint8, cutlass.Int8]:
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
# Supported out data types:
|
||||
# - Float32 (for float32 and int32 accumulator data types)
|
||||
# - Int32 (for float32 and int32 accumulator data types)
|
||||
# - Float16/BFloat16 (for fp16 and fp8 accumulator data types)
|
||||
# - Int8/Uint8 (for uint8/int8 accumulator data types)
|
||||
# - Float8E4M3FN/Float8E5M2 (for float32 accumulator data types)
|
||||
if operands.out.dtype == cutlass.Float32 or operands.out.dtype == cutlass.Int32:
|
||||
if operands.accumulator_type not in [cutlass.Float32, cutlass.Int32]:
|
||||
return False
|
||||
elif (
|
||||
operands.out.dtype == cutlass.Float16
|
||||
or operands.out.dtype == cutlass.BFloat16
|
||||
):
|
||||
if operands.accumulator_type not in [cutlass.Float16, cutlass.BFloat16]:
|
||||
return False
|
||||
elif operands.out.dtype == cutlass.Int8 or operands.out.dtype == cutlass.Uint8:
|
||||
if operands.accumulator_type not in [cutlass.Int32]:
|
||||
return False
|
||||
elif (
|
||||
operands.out.dtype == cutlass.Float8E4M3FN
|
||||
or operands.out.dtype == cutlass.Float8E5M2
|
||||
):
|
||||
if operands.accumulator_type not in [cutlass.Float32]:
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _metadata_operand_combinations() -> Generator[GemmOperandsMetadata, None, None]:
|
||||
"""
|
||||
Generator that yields all valid GemmOperandsMetadata combinations
|
||||
based on the validation rules in _valid_operands.
|
||||
"""
|
||||
# Supported A/B data types (must be the same)
|
||||
ab_dtypes = [
|
||||
cutlass.Float32,
|
||||
cutlass.Float16,
|
||||
cutlass.BFloat16,
|
||||
cutlass.Int8,
|
||||
cutlass.Uint8,
|
||||
cutlass.Float8E4M3FN,
|
||||
cutlass.Float8E5M2,
|
||||
]
|
||||
|
||||
row_major_stride = (0, 0, 1)
|
||||
col_major_stride = (0, 1, 0)
|
||||
alignment = 16
|
||||
|
||||
for ab_dtype in ab_dtypes:
|
||||
# Determine valid accumulator types for this A/B dtype
|
||||
valid_acc_dtypes = []
|
||||
|
||||
if (
|
||||
ab_dtype.is_float
|
||||
): # Float32, Float16, BFloat16, Float8E4M3FN, Float8E5M2
|
||||
valid_acc_dtypes.append(cutlass.Float32)
|
||||
if ab_dtype in [
|
||||
cutlass.Float16,
|
||||
cutlass.Float8E4M3FN,
|
||||
cutlass.Float8E5M2,
|
||||
]:
|
||||
valid_acc_dtypes.append(cutlass.Float16)
|
||||
else: # Int8, Uint8
|
||||
valid_acc_dtypes.append(cutlass.Int32)
|
||||
|
||||
for acc_dtype in valid_acc_dtypes:
|
||||
# Determine valid output types for this accumulator type
|
||||
valid_out_dtypes = []
|
||||
|
||||
if acc_dtype == cutlass.Float32:
|
||||
valid_out_dtypes.extend(
|
||||
[
|
||||
cutlass.Float8E4M3FN,
|
||||
cutlass.Float8E5M2,
|
||||
cutlass.Float16,
|
||||
cutlass.BFloat16,
|
||||
cutlass.Float32,
|
||||
cutlass.Int32,
|
||||
]
|
||||
)
|
||||
elif acc_dtype == cutlass.Int32:
|
||||
valid_out_dtypes.extend(
|
||||
[cutlass.Int8, cutlass.Uint8, cutlass.Float32, cutlass.Int32]
|
||||
)
|
||||
elif acc_dtype == cutlass.Float16:
|
||||
valid_out_dtypes.extend([cutlass.Float16])
|
||||
|
||||
for out_dtype in valid_out_dtypes:
|
||||
for stride_A, stride_B, stride_out in itertools.product(
|
||||
[row_major_stride, col_major_stride], repeat=3
|
||||
):
|
||||
# Create TensorAttributes for A, B, and out tensors
|
||||
a_attrs = TensorAttributes(
|
||||
dtype=ab_dtype, stride=stride_A, alignment=alignment
|
||||
)
|
||||
b_attrs = TensorAttributes(
|
||||
dtype=ab_dtype, stride=stride_B, alignment=alignment
|
||||
)
|
||||
out_attrs = TensorAttributes(
|
||||
dtype=out_dtype, stride=stride_out, alignment=alignment
|
||||
)
|
||||
|
||||
# Create and yield the GemmOperandsMetadata
|
||||
operands = GemmOperandsMetadata(
|
||||
A=a_attrs,
|
||||
B=b_attrs,
|
||||
out=out_attrs,
|
||||
accumulator_type=acc_dtype,
|
||||
)
|
||||
|
||||
yield operands
|
||||
|
||||
@staticmethod
|
||||
def _valid_metadata(metadata: KernelMetadata) -> bool:
|
||||
if not PersistentDenseGemmKernel._valid_operands(metadata.operands):
|
||||
return False
|
||||
|
||||
design = metadata.design
|
||||
if not isinstance(design, Sm100DesignMetadata):
|
||||
return False
|
||||
|
||||
cluster_size_m, cluster_size_n, _ = design.cluster_shape
|
||||
|
||||
if cluster_size_m % 2 != 0 and cluster_size_m != 1:
|
||||
return False
|
||||
if cluster_size_n % 2 != 0 and cluster_size_n != 1:
|
||||
return False
|
||||
if cluster_size_m * cluster_size_n > 16:
|
||||
return False
|
||||
|
||||
# Constraints based on whether 2CTA instructions are used
|
||||
if design.use_2cta_mma is not None:
|
||||
if design.use_2cta_mma:
|
||||
if cluster_size_m % 2 != 0:
|
||||
return False
|
||||
if design.tile_shape is not None and design.tile_shape[0] not in [
|
||||
128,
|
||||
256,
|
||||
]:
|
||||
return False
|
||||
else:
|
||||
if design.tile_shape is not None and design.tile_shape[0] not in [
|
||||
64,
|
||||
128,
|
||||
]:
|
||||
return False
|
||||
|
||||
if design.tile_shape is not None and design.tile_shape[1] not in range(
|
||||
32, 256, 32
|
||||
):
|
||||
return False
|
||||
|
||||
if metadata.epilogue is not None:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def generate_kernels(
|
||||
metadata_filter: Callable[[KernelMetadata], bool],
|
||||
epilogue_args: EpilogueArguments = None,
|
||||
cc: int = None,
|
||||
) -> list["PersistentDenseGemmKernel"]:
|
||||
"""
|
||||
Returns a list of all possible configurations of PersistentDenseGemmKernel that
|
||||
adhere to constraints passed in under kwargs.
|
||||
"""
|
||||
if cc is not None and cc not in [100, 101, 103]:
|
||||
return []
|
||||
|
||||
design_params = {
|
||||
"use_2cta_mma": [True, False],
|
||||
"tile_shape": [
|
||||
(M, N, 256) for M in [64, 128, 256] for N in [32, 64, 128, 256]
|
||||
],
|
||||
"cluster_shape": [
|
||||
(M, N, 1) for M in [1, 2, 4, 8, 16] for N in [1, 2, 4, 8, 16]
|
||||
],
|
||||
"use_tma_store": [True],
|
||||
}
|
||||
|
||||
if epilogue_args is not None:
|
||||
return []
|
||||
|
||||
from itertools import product
|
||||
|
||||
# Get the list of tunable parameter names and their possible values
|
||||
param_names = list(design_params.keys())
|
||||
param_values = [design_params[name] for name in param_names]
|
||||
|
||||
kernel_list = []
|
||||
|
||||
for operands in PersistentDenseGemmKernel._metadata_operand_combinations():
|
||||
for values in product(*param_values):
|
||||
design = Sm100DesignMetadata(**dict(zip(param_names, values)))
|
||||
kernel_name = "cutedsl.PersistentDenseGemmKernel_sm100_{layout}_A{A}_B{B}_out{out}_acc{acc}_{num_cta}cta_cluster{cluster}_tile{tile}{_tma_store}".format(
|
||||
layout=strides_to_layout_string(
|
||||
operands.A.stride, operands.B.stride, operands.out.stride
|
||||
),
|
||||
A=operands.A.dtype,
|
||||
B=operands.B.dtype,
|
||||
out=operands.out.dtype,
|
||||
acc=operands.accumulator_type,
|
||||
num_cta=("2" if design.use_2cta_mma else "1"),
|
||||
cluster=tuple_to_string(design.cluster_shape),
|
||||
tile=tuple_to_string(design.tile_shape),
|
||||
_tma_store="_tma_store" if design.use_tma_store else "",
|
||||
)
|
||||
|
||||
metadata = KernelMetadata(
|
||||
operands=operands,
|
||||
design=design,
|
||||
kernel_name=kernel_name,
|
||||
kernel_class=PersistentDenseGemmKernel,
|
||||
min_cc=100,
|
||||
epilogue=None,
|
||||
)
|
||||
|
||||
if PersistentDenseGemmKernel._valid_metadata(
|
||||
metadata
|
||||
) and metadata_filter(metadata):
|
||||
kernel_list.append(
|
||||
PersistentDenseGemmKernel(metadata)
|
||||
)
|
||||
|
||||
return kernel_list
|
||||
@@ -0,0 +1,443 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import itertools
|
||||
from collections.abc import Callable, Generator
|
||||
|
||||
import cutlass
|
||||
|
||||
from cutlass_api.arguments import (
|
||||
EpilogueArguments,
|
||||
GemmArguments,
|
||||
)
|
||||
from cutlass_api.artifact import CompiledArtifact
|
||||
from cutlass_api.metadata import (
|
||||
EpilogueMetadata,
|
||||
GemmOperandsMetadata,
|
||||
KernelMetadata,
|
||||
Sm100DesignMetadata,
|
||||
TensorAttributes,
|
||||
)
|
||||
from cutlass_api.providers.cutedsl import CuTeDSLProvider
|
||||
from cutlass_api.providers.cutedsl.evt.converter import EFCConverter
|
||||
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
|
||||
from cutlass_api.providers.cutedsl.utils import get_max_active_clusters
|
||||
from cutlass_api.status import Status
|
||||
from cutlass_api.utils import (
|
||||
TensorWrapper,
|
||||
strides_to_layout_string,
|
||||
to_cuda_stream,
|
||||
tuple_to_string,
|
||||
)
|
||||
|
||||
from .implementations.sm100_static_persistent_efc_impl import (
|
||||
PersistentDenseGemmEFCKernelImpl,
|
||||
)
|
||||
|
||||
|
||||
@CuTeDSLProvider.register
|
||||
class PersistentDenseGemmEFCKernel(CuteDslKernel):
|
||||
"""Base class for batched GEMM with custom epilogue fusion using EFC.
|
||||
|
||||
This class provides the core infrastructure for persistent batched GEMM operations
|
||||
with customizable epilogue fusion. Subclasses define specific epilogue behaviors
|
||||
by providing an epilogue configuration function that describes operations on the
|
||||
accumulator and supplemental tensors.
|
||||
|
||||
The class handles:
|
||||
- GEMM mainloop (A * B computation)
|
||||
- TMA-based memory operations
|
||||
- Warp specialization
|
||||
- Persistent tile scheduling
|
||||
- EFC (Epilogue Fusion Configuration) integration
|
||||
- CLI argument parsing (extensible via CLIParser.more_parsing())
|
||||
- Tensor creation and validation
|
||||
|
||||
:param acc_dtype: Data type for accumulation during MMA computation
|
||||
:type acc_dtype: type[cutlass.Numeric]
|
||||
:param epi_dtype: Data type for epilogue operation
|
||||
:type epi_dtype: type[cutlass.Numeric]
|
||||
:param use_2cta_instrs: Whether to use CTA group 2 for 2CTA MMA instructions
|
||||
:type use_2cta_instrs: bool
|
||||
:param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N)
|
||||
:type mma_tiler_mn: tuple[int, int]
|
||||
:param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing
|
||||
:type cluster_shape_mn: tuple[int, int]
|
||||
:param epilogue_configuration_function: Function defining the epilogue behavior via EFC
|
||||
:type epilogue_configuration_function: Callable
|
||||
|
||||
:note: Supported A/B data types:
|
||||
- TFloat32
|
||||
- Float16/BFloat16
|
||||
- Int8/Uint8
|
||||
- Float8E4M3FN/Float8E5M2
|
||||
(A and B must have the same data type)
|
||||
|
||||
:note: Supported accumulator data types:
|
||||
- Float32 (for all floating point A/B data types)
|
||||
- Float16 (only for fp16 and fp8 A/B data types)
|
||||
- Int32 (only for uint8/int8 A/B data types)
|
||||
|
||||
:note: Supported supplemental tensor data types (epilogue-dependent):
|
||||
- Float32 (for float32 and int32 accumulator data types)
|
||||
- Int32 (for float32 and int32 accumulator data types)
|
||||
- Float16/BFloat16 (for fp16 and fp8 accumulator data types)
|
||||
- Int8/Uint8 (for uint8/int8 accumulator data types)
|
||||
- Float8E4M3FN/Float8E5M2 (for float32 accumulator data types)
|
||||
|
||||
:note: Constraints:
|
||||
- MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True)
|
||||
- MMA tiler N must be 32-256, step 32
|
||||
- Cluster shape M must be multiple of 2 if use_2cta_instrs=True
|
||||
- Cluster shape M/N must be positive and power of 2, total cluster size <= 16
|
||||
"""
|
||||
|
||||
_valid_ab_acc_combinations = {
|
||||
cutlass.Float16: {cutlass.Float16, cutlass.Float32},
|
||||
cutlass.BFloat16: {cutlass.Float32},
|
||||
cutlass.TFloat32: {cutlass.Float32},
|
||||
cutlass.Uint8: {cutlass.Int32},
|
||||
cutlass.Int8: {cutlass.Int32},
|
||||
cutlass.Float8E4M3FN: {cutlass.Float16, cutlass.Float32},
|
||||
cutlass.Float8E5M2: {cutlass.Float16, cutlass.Float32},
|
||||
}
|
||||
|
||||
def __init__(self, metadata: KernelMetadata):
|
||||
self.metadata = metadata
|
||||
|
||||
if metadata.epilogue is not None:
|
||||
epilogue_op = EFCConverter.convert(
|
||||
metadata.epilogue.traced_epilogue.dag_ir,
|
||||
metadata.epilogue.parameter_names,
|
||||
)
|
||||
else:
|
||||
epilogue_op = EFCConverter.identity_efc
|
||||
|
||||
mma_tiler_mn = (metadata.design.tile_shape[0], metadata.design.tile_shape[1])
|
||||
cluster_shape_mn = (
|
||||
metadata.design.cluster_shape[0],
|
||||
metadata.design.cluster_shape[1],
|
||||
)
|
||||
|
||||
self.impl = PersistentDenseGemmEFCKernelImpl(
|
||||
metadata.operands.accumulator_type,
|
||||
metadata.operands.out.dtype,
|
||||
metadata.design.use_2cta_mma,
|
||||
mma_tiler_mn,
|
||||
cluster_shape_mn,
|
||||
epilogue_op,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _valid_fusion(fusion: EpilogueMetadata) -> Status:
|
||||
if not isinstance(fusion, EpilogueMetadata):
|
||||
return Status.fail("Unsupported epilogue argument type.")
|
||||
if not (status := EFCConverter.supports(fusion.traced_epilogue.dag_ir)):
|
||||
return status
|
||||
return Status.success()
|
||||
|
||||
@staticmethod
|
||||
def _valid_operands(operands: GemmOperandsMetadata) -> Status:
|
||||
if not isinstance(operands, GemmOperandsMetadata):
|
||||
return False
|
||||
|
||||
if operands.A.dtype != operands.B.dtype:
|
||||
return False
|
||||
|
||||
ab_dtype = operands.A.dtype
|
||||
acc_dtype = operands.accumulator_type
|
||||
|
||||
if ab_dtype not in PersistentDenseGemmEFCKernel._valid_ab_acc_combinations:
|
||||
return False
|
||||
|
||||
# Check compatibility between accumulator type and AB type
|
||||
if (
|
||||
acc_dtype
|
||||
not in PersistentDenseGemmEFCKernel._valid_ab_acc_combinations[ab_dtype]
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _valid_design(design: Sm100DesignMetadata) -> bool:
|
||||
"""
|
||||
Check if the design metadata is valid.
|
||||
|
||||
:param design: The design metadata
|
||||
:type design: Sm100DesignMetadata
|
||||
|
||||
:return: True if the design is valid, False otherwise
|
||||
:rtype: bool
|
||||
"""
|
||||
if not isinstance(design, Sm100DesignMetadata):
|
||||
return False
|
||||
|
||||
use_2cta_instrs = design.use_2cta_mma
|
||||
mma_tiler_mn = design.tile_shape
|
||||
cluster_shape_mn = design.cluster_shape
|
||||
|
||||
# Check invalid mma tile shape M dimension
|
||||
if not (
|
||||
(not use_2cta_instrs and mma_tiler_mn[0] in [64, 128])
|
||||
or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256])
|
||||
):
|
||||
return False
|
||||
|
||||
# Check invalid mma tile shape N dimension
|
||||
if mma_tiler_mn[1] not in range(32, 257, 32):
|
||||
return False
|
||||
# Check illegal cluster shape M dimension
|
||||
if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0:
|
||||
return False
|
||||
|
||||
def is_power_of_2(x):
|
||||
return x > 0 and (x & (x - 1)) == 0
|
||||
|
||||
# Check invalid cluster shape constraints
|
||||
if cluster_shape_mn[0] * cluster_shape_mn[1] > 16:
|
||||
return False
|
||||
if cluster_shape_mn[0] <= 0 or cluster_shape_mn[1] <= 0:
|
||||
return False
|
||||
if not is_power_of_2(cluster_shape_mn[0]) or not is_power_of_2(
|
||||
cluster_shape_mn[1]
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _valid_metadata(metadata: KernelMetadata) -> bool:
|
||||
if not PersistentDenseGemmEFCKernel._valid_operands(metadata.operands):
|
||||
return False
|
||||
|
||||
if not PersistentDenseGemmEFCKernel._valid_design(metadata.design):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _metadata_operand_combinations() -> Generator[GemmOperandsMetadata, None, None]:
|
||||
"""
|
||||
Generator that yields all valid GemmOperandsMetadata combinations
|
||||
based on the validation rules in _valid_operands.
|
||||
"""
|
||||
row_major_stride = (0, 0, 1)
|
||||
col_major_stride = (0, 1, 0)
|
||||
alignment = 16
|
||||
|
||||
for (
|
||||
ab_dtype,
|
||||
valid_acc_dtypes,
|
||||
) in PersistentDenseGemmEFCKernel._valid_ab_acc_combinations.items():
|
||||
for acc_dtype in valid_acc_dtypes:
|
||||
# Determine valid output types for this accumulator type
|
||||
valid_out_dtypes = []
|
||||
|
||||
if acc_dtype == cutlass.Float32:
|
||||
valid_out_dtypes.extend(
|
||||
[
|
||||
cutlass.Float32,
|
||||
cutlass.Int32,
|
||||
cutlass.Float16,
|
||||
cutlass.BFloat16,
|
||||
]
|
||||
)
|
||||
# Float8 output types only valid with Float32 accumulator
|
||||
valid_out_dtypes.extend([cutlass.Float8E4M3FN, cutlass.Float8E5M2])
|
||||
elif acc_dtype == cutlass.Int32:
|
||||
valid_out_dtypes.extend([cutlass.Float32, cutlass.Int32])
|
||||
# Integer output types only valid with Int32 accumulator
|
||||
valid_out_dtypes.extend([cutlass.Int8, cutlass.Uint8])
|
||||
elif acc_dtype == cutlass.Float16:
|
||||
valid_out_dtypes.extend([cutlass.Float16, cutlass.BFloat16])
|
||||
|
||||
for out_dtype in valid_out_dtypes:
|
||||
for stride_A, stride_B, stride_out in itertools.product(
|
||||
[row_major_stride, col_major_stride], repeat=3
|
||||
):
|
||||
# Create TensorAttributes for A, B, and out tensors
|
||||
a_attrs = TensorAttributes(
|
||||
dtype=ab_dtype, stride=stride_A, alignment=alignment
|
||||
)
|
||||
b_attrs = TensorAttributes(
|
||||
dtype=ab_dtype, stride=stride_B, alignment=alignment
|
||||
)
|
||||
out_attrs = TensorAttributes(
|
||||
dtype=out_dtype, stride=stride_out, alignment=alignment
|
||||
)
|
||||
|
||||
# Create and yield the GemmOperandsMetadata
|
||||
operands = GemmOperandsMetadata(
|
||||
A=a_attrs,
|
||||
B=b_attrs,
|
||||
out=out_attrs,
|
||||
accumulator_type=acc_dtype,
|
||||
)
|
||||
|
||||
yield operands
|
||||
|
||||
def _supports(self, args: GemmArguments) -> Status:
|
||||
if args.epilogue is not None:
|
||||
fusion_metadata = EpilogueMetadata.from_args(args.epilogue)
|
||||
if not self._valid_fusion(fusion_metadata):
|
||||
return Status.fail("Provided epilogue fusion is not supported by this kernel")
|
||||
|
||||
return Status.success()
|
||||
|
||||
def compile(self, args: GemmArguments, cc: int = None) -> CompiledArtifact:
|
||||
stream = cutlass.cute.runtime.make_fake_stream()
|
||||
max_active_clusters = get_max_active_clusters(self.impl.cluster_shape_mn)
|
||||
|
||||
if args.epilogue is not None:
|
||||
epilogue_params = args.epilogue.parameters
|
||||
else:
|
||||
epilogue_params = [args.out]
|
||||
|
||||
epilogue_params = [
|
||||
e.compile_time_tensor if isinstance(e, TensorWrapper) else e
|
||||
for e in epilogue_params
|
||||
]
|
||||
|
||||
# EFC needs special handling for supplemental arguments
|
||||
self.impl.efc.compile(epilogue_params)
|
||||
compiled_gemm = self.cute_compile(
|
||||
self.impl,
|
||||
args.A,
|
||||
args.B,
|
||||
max_active_clusters,
|
||||
stream,
|
||||
self.impl.efc.jit.pack_arguments(*epilogue_params),
|
||||
)
|
||||
|
||||
# Wrap the compiled kernel to handle supplemental argument packing at launch time
|
||||
def wrapped_launch(a_tensor, b_tensor, stream, *supplemental_args):
|
||||
runtime_args = [
|
||||
e.runtime_tensor if isinstance(e, TensorWrapper) else e
|
||||
for e in supplemental_args
|
||||
]
|
||||
return compiled_gemm(
|
||||
a_tensor,
|
||||
b_tensor,
|
||||
stream,
|
||||
self.impl.efc.jit.pack_arguments(*runtime_args),
|
||||
)
|
||||
|
||||
return CompiledArtifact(wrapped_launch, self)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
args: GemmArguments,
|
||||
compiled_artifact: CompiledArtifact,
|
||||
stream,
|
||||
workspace=None,
|
||||
) -> None:
|
||||
stream = to_cuda_stream(stream)
|
||||
|
||||
if args.epilogue is not None:
|
||||
epilogue_params = args.epilogue.parameters
|
||||
else:
|
||||
epilogue_params = [args.out]
|
||||
|
||||
compiled_gemm = compiled_artifact.compiled_obj
|
||||
self.cute_run(compiled_gemm, args.A, args.B, stream, *epilogue_params)
|
||||
|
||||
@staticmethod
|
||||
def generate_kernels(
|
||||
metadata_filter: Callable[[KernelMetadata], bool],
|
||||
epilogue_args: EpilogueArguments = None,
|
||||
cc: int = None,
|
||||
) -> list["PersistentDenseGemmEFCKernel"]:
|
||||
"""
|
||||
Returns a list of all possible configurations of PersistentDenseGemmEFCKernel that
|
||||
adhere to constraints passed in under kwargs.
|
||||
"""
|
||||
if cc is not None and cc not in [100, 101, 103]:
|
||||
return []
|
||||
|
||||
design_params = {
|
||||
"use_2cta_mma": [True, False],
|
||||
"tile_shape": [
|
||||
(M, N, 256) for M in [64, 128, 256] for N in [32, 64, 128, 256]
|
||||
],
|
||||
"cluster_shape": [
|
||||
(M, N, 1) for M in [1, 2, 4, 8, 16] for N in [1, 2, 4, 8, 16]
|
||||
],
|
||||
"use_tma_store": [True],
|
||||
}
|
||||
|
||||
if epilogue_args is not None:
|
||||
if not isinstance(epilogue_args, EpilogueArguments):
|
||||
return []
|
||||
|
||||
epilogue_metadata = EpilogueMetadata.from_args(epilogue_args)
|
||||
if not PersistentDenseGemmEFCKernel._valid_fusion(epilogue_metadata):
|
||||
return []
|
||||
else:
|
||||
epilogue_metadata = None
|
||||
|
||||
from itertools import product
|
||||
|
||||
param_names = list(design_params.keys())
|
||||
param_values = [design_params[name] for name in param_names]
|
||||
|
||||
kernel_list = []
|
||||
for operands in PersistentDenseGemmEFCKernel._metadata_operand_combinations():
|
||||
for values in product(*param_values):
|
||||
design = Sm100DesignMetadata(**dict(zip(param_names, values)))
|
||||
kernel_name = "cutedsl.PersistentDenseGemmEFCKernel_sm100_{layout}_A{A}_B{B}_out{out}_acc{acc}_{num_cta}cta_cluster{cluster}_tile{tile}{_tma_store}".format(
|
||||
layout=strides_to_layout_string(
|
||||
operands.A.stride, operands.B.stride, operands.out.stride
|
||||
),
|
||||
A=operands.A.dtype,
|
||||
B=operands.B.dtype,
|
||||
out=operands.out.dtype,
|
||||
acc=operands.accumulator_type,
|
||||
num_cta=("2" if design.use_2cta_mma else "1"),
|
||||
cluster=tuple_to_string(design.cluster_shape),
|
||||
tile=tuple_to_string(design.tile_shape),
|
||||
_tma_store="_tma_store" if design.use_tma_store else "",
|
||||
)
|
||||
|
||||
metadata = KernelMetadata(
|
||||
operands=operands,
|
||||
design=design,
|
||||
kernel_name=kernel_name,
|
||||
kernel_class=PersistentDenseGemmEFCKernel,
|
||||
min_cc=100,
|
||||
epilogue=epilogue_metadata,
|
||||
)
|
||||
|
||||
if PersistentDenseGemmEFCKernel._valid_metadata(
|
||||
metadata
|
||||
) and metadata_filter(metadata):
|
||||
kernel_list.append(
|
||||
PersistentDenseGemmEFCKernel(metadata)
|
||||
)
|
||||
|
||||
return kernel_list
|
||||
73
python/cutlass_api/cutlass_api/providers/cutedsl/kernel.py
Normal file
73
python/cutlass_api/cutlass_api/providers/cutedsl/kernel.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from cutlass_api.config import GlobalOptions
|
||||
from cutlass_api.kernel import Kernel
|
||||
from cutlass_api.utils import TensorWrapper
|
||||
|
||||
import cutlass.cute as cute
|
||||
|
||||
|
||||
class CuteDslKernel(Kernel):
|
||||
"""
|
||||
Base class for all CuTe DSL kernels
|
||||
"""
|
||||
|
||||
def cute_compile(self, entry_point_fn, *fn_args):
|
||||
"""
|
||||
Compiles a kernel using CuTe DSL compile.
|
||||
|
||||
This method is intended to provider a provider-wide consistent implementation of compile() for
|
||||
all CuTe DSL kernels, respecting GlobalOptions.
|
||||
|
||||
:param entry_point_fn: The kernel function to compile
|
||||
:param fn_args: All arguments to pass to the kernel compilation
|
||||
:return: The compiled kernel object
|
||||
"""
|
||||
options = None
|
||||
if GlobalOptions().use_tvm_ffi:
|
||||
options = "--enable-tvm-ffi"
|
||||
|
||||
compile_args = [
|
||||
x.compile_time_tensor if isinstance(x, TensorWrapper) else x
|
||||
for x in fn_args
|
||||
]
|
||||
return cute.compile(entry_point_fn, *compile_args, options=options)
|
||||
|
||||
def cute_run(self, entry_point_fn, *fn_args):
|
||||
"""
|
||||
Extracts runtime tensors from TensorWrappers and runs the kernel.
|
||||
|
||||
:param entry_point_fn: The kernel function to run
|
||||
:param fn_args: All arguments to pass to the kernel run
|
||||
:return: The result of the kernel run
|
||||
"""
|
||||
run_args = [
|
||||
x.runtime_tensor if isinstance(x, TensorWrapper) else x for x in fn_args
|
||||
]
|
||||
return entry_point_fn(*run_args)
|
||||
43
python/cutlass_api/cutlass_api/providers/cutedsl/utils.py
Normal file
43
python/cutlass_api/cutlass_api/providers/cutedsl/utils.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
import cutlass.utils as utils
|
||||
|
||||
|
||||
def get_max_active_clusters(cluster_shape: tuple[int, int]) -> int:
|
||||
"""
|
||||
:param cluster_shape: The shape of the cluster.
|
||||
:type cluster_shape: tuple[int, int]
|
||||
|
||||
:returns: The maximum number clusters of the provided shape that can fit on the current GPU
|
||||
:rtype: int
|
||||
"""
|
||||
return utils.HardwareInfo().get_max_active_clusters(
|
||||
cluster_shape[0] * cluster_shape[1]
|
||||
)
|
||||
72
python/cutlass_api/cutlass_api/providers/provider.py
Normal file
72
python/cutlass_api/cutlass_api/providers/provider.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from cutlass_api.arguments import EpilogueArguments
|
||||
from cutlass_api.kernel import Kernel
|
||||
from cutlass_api.metadata import KernelMetadata
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ProviderBase(Protocol):
|
||||
@classmethod
|
||||
def generate_kernels(
|
||||
cls,
|
||||
metadata_filter: Callable[[KernelMetadata], bool] | None = None,
|
||||
epilogue_args: EpilogueArguments | None = None,
|
||||
cc: int | None = None,
|
||||
) -> list[Kernel]:
|
||||
"""
|
||||
Return a list of kernels that support the type of the provided metadata.
|
||||
|
||||
:param metadata_filter: boolean filter function to apply to the metadata
|
||||
:type metadata_filter: Callable[[KernelMetadata], bool]
|
||||
:param epilogue_args: the epilogue arguments
|
||||
:type epilogue_args: EpilogueArguments
|
||||
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
|
||||
:type cc: int
|
||||
|
||||
:return: list of all supported kernel configurations
|
||||
:rtype: list[Kernel]
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def register(cls, kernel_class: type[Kernel]) -> type[Kernel]:
|
||||
"""
|
||||
Registers a Kernel class as being able to be discovered through the given provider.
|
||||
|
||||
:param kernel_class: the Kernel class to register
|
||||
:type kernel_class: Type[Kernel]
|
||||
|
||||
:return: the registered kernel class
|
||||
:rtype: Type[Kernel]
|
||||
"""
|
||||
raise NotImplementedError
|
||||
61
python/cutlass_api/cutlass_api/status.py
Normal file
61
python/cutlass_api/cutlass_api/status.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Status:
|
||||
"""
|
||||
A simple status class to wrap an optional exception.
|
||||
"""
|
||||
|
||||
error: Exception | None = None
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return self.error is None
|
||||
|
||||
@classmethod
|
||||
def success(cls) -> Status:
|
||||
"""Create a successful status."""
|
||||
return cls()
|
||||
|
||||
@classmethod
|
||||
def fail(cls, error: str | Exception) -> Status:
|
||||
"""Create a failed status with an error."""
|
||||
if isinstance(error, str):
|
||||
error = ValueError(error)
|
||||
return cls(error=error)
|
||||
|
||||
def raise_on_error(self) -> None:
|
||||
"""Raise the stored exception if this status represents a failure."""
|
||||
if self.error is not None:
|
||||
raise self.error
|
||||
|
||||
60
python/cutlass_api/cutlass_api/typing.py
Normal file
60
python/cutlass_api/cutlass_api/typing.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""Type markers for CUTLASS API field annotations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class TensorLike(Protocol):
|
||||
"""Protocol for Tensor-like objects that support the DLPack protocol.
|
||||
|
||||
CUTLASS API supports tensor-like objects (e.g., torch.Tensor, cute.Tensor) that
|
||||
implement __dlpack__ and __dlpack_device__.
|
||||
"""
|
||||
|
||||
def __dlpack__(self, *, stream: Any = None) -> Any:
|
||||
"""Return a DLPack capsule representing the tensor data."""
|
||||
...
|
||||
|
||||
def __dlpack_device__(self) -> tuple[int, int]:
|
||||
"""Return the device type and device id as a tuple."""
|
||||
...
|
||||
|
||||
|
||||
class NumericLike(Protocol):
|
||||
"""Type marker for fields that accept numeric-like types.
|
||||
|
||||
Fields annotated with NumericLike accept the following:
|
||||
|
||||
- cutlass.Numeric
|
||||
- torch.dtype
|
||||
"""
|
||||
425
python/cutlass_api/cutlass_api/utils.py
Normal file
425
python/cutlass_api/cutlass_api/utils.py
Normal file
@@ -0,0 +1,425 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""
|
||||
Utilities for CUTLASS API.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
from cutlass_api.config import GlobalOptions
|
||||
from cutlass_api.status import Status
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import cuda
|
||||
import torch
|
||||
|
||||
|
||||
def is_numpy_available() -> bool:
|
||||
"""Check if numpy is available."""
|
||||
return importlib.util.find_spec("numpy") is not None
|
||||
|
||||
|
||||
def is_torch_available() -> bool:
|
||||
"""Check if torch is available."""
|
||||
return importlib.util.find_spec("torch") is not None
|
||||
|
||||
|
||||
def is_numpy_tensor(inp) -> bool:
|
||||
"""Check if the input is a numpy tensor."""
|
||||
if is_numpy_available():
|
||||
import numpy as np
|
||||
|
||||
return isinstance(inp, np.ndarray)
|
||||
return False
|
||||
|
||||
|
||||
def is_torch_tensor(inp) -> bool:
|
||||
"""Check if the input is a torch tensor."""
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
return isinstance(inp, torch.Tensor)
|
||||
return False
|
||||
|
||||
|
||||
def _lazy_import(mod_name: str) -> Any:
|
||||
"""Internal utility to lazily import a module only when needed."""
|
||||
|
||||
class Lazy:
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
module = importlib.import_module(mod_name)
|
||||
return getattr(module, name)
|
||||
|
||||
return Lazy()
|
||||
|
||||
|
||||
def check_cuda_errors(result: list):
|
||||
"""
|
||||
Checks whether `result` contains a CUDA error raises the error as an exception, if so. Otherwise,
|
||||
returns the result contained in the remaining fields of `result`.
|
||||
|
||||
:param result: the results of the `cuda` method, consisting of an error code and any method results
|
||||
:type result: list
|
||||
|
||||
:return: non-error-code results from the `results` parameter
|
||||
"""
|
||||
# `result` is of the format : (cudaError_t, result...)
|
||||
err = result[0]
|
||||
if err.value:
|
||||
_lazy_import("cuda.cuda")
|
||||
cuda_bindings_runtime = _lazy_import("cuda.bindings.runtime")
|
||||
raise RuntimeError(
|
||||
f"CUDA error: {cuda_bindings_runtime.cudaGetErrorString(err)[1].decode('utf-8')}"
|
||||
)
|
||||
|
||||
if len(result) == 1:
|
||||
return None
|
||||
elif len(result) == 2:
|
||||
return result[1]
|
||||
else:
|
||||
return result[1:]
|
||||
|
||||
|
||||
def device_cc(device: int = 0) -> int:
|
||||
"""
|
||||
Returns the compute capability of the device with ID `device`.
|
||||
|
||||
:param device: ID of the device to query
|
||||
:type device: int
|
||||
|
||||
:return: compute capability of the queried device (e.g., 80 for SM80)
|
||||
:rtype: int
|
||||
"""
|
||||
_lazy_import("cuda.cuda")
|
||||
cuda_bindings_runtime = _lazy_import("cuda.bindings.runtime")
|
||||
deviceProp = check_cuda_errors(
|
||||
cuda_bindings_runtime.cudaGetDeviceProperties(device)
|
||||
)
|
||||
major = str(deviceProp.major)
|
||||
minor = str(deviceProp.minor)
|
||||
return int(major + minor)
|
||||
|
||||
|
||||
def is_device_cc_supported(supported_ccs: set[int]) -> Status:
|
||||
"""
|
||||
Fetch the device compute capability, and check if it is in supported ccs
|
||||
|
||||
:return: Status indicating success if device CC is in supported_ccs
|
||||
:rtype: Status
|
||||
"""
|
||||
try:
|
||||
cc = device_cc()
|
||||
if cc not in supported_ccs:
|
||||
return Status.fail(
|
||||
f"Compute capability {cc} is not in supported set {supported_ccs}."
|
||||
)
|
||||
return Status.success()
|
||||
except Exception as e:
|
||||
return Status.fail(e)
|
||||
|
||||
|
||||
def cutlass_type_from_torch_type(dtype) -> type[cutlass.Numeric]:
|
||||
"""
|
||||
Convert a torch dtype to a cutlass dtype.
|
||||
|
||||
:param dtype: The torch dtype to convert.
|
||||
|
||||
:return: The cutlass dtype.
|
||||
:rtype: type[cutlass.Numeric]
|
||||
"""
|
||||
import torch
|
||||
|
||||
torch_dtype_map = {
|
||||
torch.float64: cutlass.Float64,
|
||||
torch.float32: cutlass.Float32,
|
||||
torch.float16: cutlass.Float16,
|
||||
torch.bfloat16: cutlass.BFloat16,
|
||||
torch.int32: cutlass.Int32,
|
||||
torch.int8: cutlass.Int8,
|
||||
torch.uint8: cutlass.Uint8,
|
||||
torch.float8_e5m2: cutlass.Float8E5M2,
|
||||
torch.float8_e4m3fn: cutlass.Float8E4M3FN,
|
||||
torch.float8_e4m3fnuz: cutlass.Float8E4M3B11FNUZ,
|
||||
}
|
||||
|
||||
try:
|
||||
return torch_dtype_map[dtype]
|
||||
except KeyError:
|
||||
raise KeyError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
def to_cutlass_type(dtype) -> type[cutlass.Numeric]:
|
||||
"""
|
||||
Convert a dtype to a cutlass dtype.
|
||||
|
||||
:param dtype: The dtype to convert (e.g., torch.float32)
|
||||
|
||||
:return: The cutlass dtype.
|
||||
:rtype: type[cutlass.Numeric]
|
||||
"""
|
||||
if isinstance(dtype, type) and issubclass(dtype, cutlass.Numeric):
|
||||
return dtype
|
||||
|
||||
converters = []
|
||||
if is_torch_available():
|
||||
converters.append(cutlass_type_from_torch_type)
|
||||
|
||||
# Iterate through the available converters and return the first one that succeeds.
|
||||
for converter in converters:
|
||||
try:
|
||||
return converter(dtype)
|
||||
except KeyError:
|
||||
continue
|
||||
raise KeyError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
def to_cuda_stream(
|
||||
stream: cuda.bindings.driver.CUstream | torch.cuda.Stream,
|
||||
skip_if_ffi: bool = True,
|
||||
) -> cuda.bindings.driver.CUstream:
|
||||
"""
|
||||
Convert provided stream to a cuda.CUstream.
|
||||
|
||||
:param stream: The stream to convert.
|
||||
:type stream: Union[cuda.bindings.driver.CUstream, torch.cuda.Stream]
|
||||
:param skip_if_ffi: Skip the conversion if True and if TVM-FFI is enabled in GlobalOptions().
|
||||
:type skip_if_ffi: bool
|
||||
|
||||
:return: The converted stream.
|
||||
:rtype: cuda.bindings.driver.CUstream
|
||||
"""
|
||||
if skip_if_ffi and GlobalOptions().use_tvm_ffi:
|
||||
# TVM-FFI can directly handle streams of various types, including raw int handles
|
||||
return stream
|
||||
|
||||
cuda = _lazy_import("cuda.bindings.driver")
|
||||
if isinstance(stream, cuda.CUstream):
|
||||
return stream
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if isinstance(stream, torch.cuda.Stream):
|
||||
return cuda.CUstream(stream.cuda_stream)
|
||||
raise ValueError(f"Unsupported stream type: {type(stream)}")
|
||||
|
||||
|
||||
def add_batch_mode(
|
||||
tensor: cute.Tensor | torch.Tensor,
|
||||
) -> cute.Tensor | torch.Tensor:
|
||||
"""
|
||||
Adds a batch mode to the tensor.
|
||||
If the tensor is a torch.Tensor and has rank 2,
|
||||
it will be unsqueezed along the first dimension.
|
||||
|
||||
:param tensor: The tensor to add batch mode to.
|
||||
:type tensor: Union[cute.Tensor, "torch.Tensor"]
|
||||
|
||||
:return: The tensor with batch mode added.
|
||||
:rtype: Union[cute.Tensor, "torch.Tensor"]
|
||||
"""
|
||||
if is_torch_tensor(tensor):
|
||||
if tensor.dim() == 2:
|
||||
return tensor.unsqueeze(0)
|
||||
elif tensor.dim() < 2 or tensor.dim() > 3:
|
||||
raise ValueError(f"Expected 2-3 dimensions, got {tensor.dim()}")
|
||||
return tensor
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def permute_batch_mode(
|
||||
tensor: cute.Tensor | torch.Tensor,
|
||||
) -> cute.Tensor | torch.Tensor:
|
||||
"""
|
||||
Permute the batch mode of the tensor.
|
||||
If the tensor is a torch.Tensor and has rank 3,
|
||||
it will be permuted along the first dimension.
|
||||
|
||||
:param tensor: The tensor to permute.
|
||||
:type tensor: Union[cute.Tensor, "torch.Tensor"]
|
||||
|
||||
:return: The tensor with batch mode permuted.
|
||||
:rtype: Union[cute.Tensor, "torch.Tensor"]
|
||||
"""
|
||||
if is_torch_tensor(tensor):
|
||||
if tensor.dim() != 3:
|
||||
raise ValueError(f"Expected 3 dimensions, got {tensor.dim()}")
|
||||
return tensor.permute([1, 2, 0])
|
||||
else:
|
||||
raise ValueError(f"Unsupported type: {type(tensor)}")
|
||||
|
||||
|
||||
def leading_dim(tensor) -> int:
|
||||
"""
|
||||
Get the leading dimension of a tensor. This is the first mode with
|
||||
stride of 1 and non-unit shape. If the has a single element, the leading
|
||||
dimension is 0. Modes with both stride of 1 and shape of 1 are treated
|
||||
as though they have stride of 0.
|
||||
|
||||
:param tensor: The tensor to get the leading dimension of.
|
||||
:type tensor: Union[cute.Tensor, "torch.Tensor"]
|
||||
|
||||
:return: The leading dimension of the tensor.
|
||||
:rtype: int
|
||||
"""
|
||||
if is_torch_tensor(tensor):
|
||||
if tensor.numel() == 1:
|
||||
return 0
|
||||
updated_stride = [
|
||||
s if sz != 1 else 0 for s, sz in zip(tensor.stride(), tensor.shape)
|
||||
]
|
||||
return updated_stride.index(1)
|
||||
else:
|
||||
raise ValueError(f"Unsupported tensor type: {type(tensor)}")
|
||||
|
||||
|
||||
def get_stride_order(stride: tuple[int, ...]) -> tuple[int, ...]:
|
||||
"""
|
||||
Returns the order of the stride of a tensor. For a stride of rank N,
|
||||
the dimension with the smallest stride will have stride order 0 and the
|
||||
dimension with the largest stride will have stride order N-1.
|
||||
|
||||
:param stride: The stride of the tensor.
|
||||
:type stride: tuple[int, ...]
|
||||
|
||||
:return: The order of the stride of the tensor.
|
||||
:rtype: tuple[int, ...]
|
||||
"""
|
||||
# The code below performs an argsort on the stride:
|
||||
# indices = range(len(stride))
|
||||
# Sort indices using comparison between stride[i] and stride[j] when
|
||||
# sorting indices i and j.
|
||||
return tuple(sorted(range(len(stride)), key=stride.__getitem__))
|
||||
|
||||
|
||||
class TensorWrapper:
|
||||
"""
|
||||
Wrapper class for supporting compilation and execution both with
|
||||
and without TVM-FFI.
|
||||
|
||||
When using TVM-FFI, one can pass a framework-level tensor (e.g., torch.Tensor)
|
||||
to the JIT function at run time, but not at compile time. At compile time, one
|
||||
must use a `_FakeTensor` to specify the tensor.
|
||||
|
||||
When not using TVM-FFI, one passes in a cute.Tensor at both compile and run time.
|
||||
|
||||
This class contains two key members:
|
||||
- runtime_tensor: The tensor to use at run time.
|
||||
- compile_time_tensor: The tensor to use at compile time.
|
||||
|
||||
Users of this class should access each of these underlying members for execution
|
||||
and runtime, respectively. Users of this class do not need to know whether TVM-FFI
|
||||
is enabled.
|
||||
"""
|
||||
|
||||
def __init__(self, tensor: Any):
|
||||
if isinstance(tensor, cute.Tensor):
|
||||
# Regardless of whether TVM-FFI is enabled, if the tensor passed in is a cute.Tensor,
|
||||
# it can be used as the runtime tensor and compile time tensor.
|
||||
self.runtime_tensor = tensor
|
||||
self.compile_time_tensor = tensor
|
||||
self._shape = tensor.shape
|
||||
self._stride = tensor.stride
|
||||
elif GlobalOptions().use_tvm_ffi:
|
||||
# If TVM-FFI is enabled, runtime tensor is set simply as the tensor passed in, but
|
||||
# we must make a fake tensor for compilation.
|
||||
self.runtime_tensor = tensor
|
||||
if is_torch_tensor(self.runtime_tensor):
|
||||
dtype = cutlass_type_from_torch_type(self.runtime_tensor.dtype)
|
||||
|
||||
rank = self.runtime_tensor.dim()
|
||||
self._stride = self.runtime_tensor.stride()
|
||||
stride_order = get_stride_order(self._stride)
|
||||
|
||||
shape = [cute.SymInt() for _ in range(rank)]
|
||||
shape[stride_order.index(0)] = cute.SymInt(divisibility=16)
|
||||
self._shape = tuple(self.runtime_tensor.shape)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported tensor type: {type(self.runtime_tensor)}"
|
||||
)
|
||||
|
||||
self.compile_time_tensor = cute.runtime.make_fake_compact_tensor(
|
||||
dtype, shape, stride_order=stride_order, assumed_align=16
|
||||
)
|
||||
else:
|
||||
# TVM-FFI is disabled and the tensor passed in is not a cute.Tensor,
|
||||
# We must convert it to a cute.Tensor
|
||||
self.runtime_tensor = from_dlpack(
|
||||
tensor, assumed_align=16
|
||||
).mark_layout_dynamic(leading_dim(tensor))
|
||||
|
||||
self._shape = self.runtime_tensor.shape
|
||||
self._stride = self.runtime_tensor.stride
|
||||
|
||||
# Since the runtime tensor is now a cute.Tensor, we can use it at
|
||||
# compile time as well
|
||||
self.compile_time_tensor = self.runtime_tensor
|
||||
|
||||
@property
|
||||
def element_type(self) -> type[cutlass.Numeric]:
|
||||
return self.compile_time_tensor.element_type
|
||||
|
||||
@property
|
||||
def shape(self) -> tuple[int, ...]:
|
||||
return self._shape
|
||||
|
||||
@property
|
||||
def stride(self) -> tuple[int, ...]:
|
||||
return self._stride
|
||||
|
||||
|
||||
def strides_to_layout_string(*strides: list[tuple[int, ...]]) -> str:
|
||||
"""
|
||||
Convert list[stride tuples] to layout string ('t' for row-major, 'n' for col-major).
|
||||
Example: ((0, 1, 0), (0, 0, 1), (0, 0, 1)) -> "ntt"
|
||||
"""
|
||||
row_major_stride = (0, 0, 1)
|
||||
col_major_stride = (0, 1, 0)
|
||||
|
||||
layout_string_map = {
|
||||
row_major_stride: "t",
|
||||
col_major_stride: "n",
|
||||
}
|
||||
|
||||
return "".join(layout_string_map[s] for s in strides)
|
||||
|
||||
|
||||
def tuple_to_string(tuple: tuple[int, ...]) -> str:
|
||||
"""
|
||||
Convert a tuple of integers to an 'x'-separated string (e.g., (2, 3, 4) -> '2x3x4').
|
||||
"""
|
||||
return "x".join(str(x) for x in tuple) if len(tuple) > 1 else str(tuple[0])
|
||||
558
python/cutlass_api/examples/000_gemm.ipynb
Normal file
558
python/cutlass_api/examples/000_gemm.ipynb
Normal file
@@ -0,0 +1,558 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3dd45ef2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Basic GEMM using CUTLASS Python API"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4709aa60",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The CUTLASS API provides a consistent, uniform interface for discovering, compiling, and running GPU kernels from various DSL sources.\n",
|
||||
"\n",
|
||||
"This notebook walks through a minimal GEMM (Generalized Matrix-Matrix Multiplication) example, and introduces the core concepts of the API."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "f878d960-d175-4d84-b978-88afbd318850",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"import cutlass\n",
|
||||
"\n",
|
||||
"import cutlass_api\n",
|
||||
"\n",
|
||||
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
|
||||
" print(\n",
|
||||
" f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\"\n",
|
||||
" )\n",
|
||||
" import sys\n",
|
||||
"\n",
|
||||
" sys.exit(0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "db91dab6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Running your first kernel"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7b7b87b0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Setting up arguments\n",
|
||||
"\n",
|
||||
"CUTLASS API has first-class support for PyTorch tensors. We start by creating torch tensors that will be operands to a matrix multiplication."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "f550c4ea",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"M, N, K, L = 128, 256, 64, 2\n",
|
||||
"ab_type = torch.float16\n",
|
||||
"out_type = torch.float32\n",
|
||||
"acc_type = torch.float32\n",
|
||||
"\n",
|
||||
"A = torch.randint(-1, 2, (L, M, K), device=\"cuda\", dtype=ab_type)\n",
|
||||
"B = torch.randint(-1, 2, (L, K, N), device=\"cuda\", dtype=ab_type)\n",
|
||||
"out = torch.empty((L, M, N), device=\"cuda\", dtype=out_type)\n",
|
||||
"\n",
|
||||
"reference = (A @ B).to(out.dtype)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6b6cb805",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We then create a `GemmArguments` object. This object specifies:\n",
|
||||
"1. what logical operation do we want to perform (a GEMM)\n",
|
||||
"2. on which operands we want to perform that operation (`A`, `B`, `out` as declared above)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "b57690df",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"args = cutlass_api.arguments.GemmArguments(A=A, B=B, out=out, accumulator_type=acc_type)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "67e5ddcf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Kernel discovery\n",
|
||||
"\n",
|
||||
"We now need to find kernels that can perform the operation we expressed in `args`.\n",
|
||||
"\n",
|
||||
"The simplest way to do so is to use `get_kernels(args)`. It searches a set of kernels pre-registered in the library, and returns the subset of those kernels which can successfully run our given `args`.\n",
|
||||
"\n",
|
||||
"Any of these kernels will be functionally equivalent -- they may have different design or performance characteristics. We arbitrarily pick the first of the returned kernels to execute here"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "9872ad66",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"kernels = cutlass_api.get_kernels(args)\n",
|
||||
"assert kernels, \"No kernels found for the given arguments!\"\n",
|
||||
"\n",
|
||||
"kernel = kernels[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4c17693e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Run the kernel\n",
|
||||
"\n",
|
||||
"Running the kernel is as simple as `kernel.run(args)`.\n",
|
||||
"\n",
|
||||
"This implicitly JIT-compiles the kernel, and launches it on the GPU device using our given arguments."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "baf4588a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"kernel.run(args)\n",
|
||||
"\n",
|
||||
"torch.testing.assert_close(out, reference)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4d7ad85b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"One can also explicitly compile the kernel and pass this in to `kernel.run` to avoid\n",
|
||||
"JIT compilation on future invocations. Additional details related to this will be\n",
|
||||
"described below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "06f9f844",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"artifact = kernel.compile(args)\n",
|
||||
"kernel.run(args, compiled_artifact=artifact)\n",
|
||||
"torch.testing.assert_close(out, reference)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "630e9e4b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"\n",
|
||||
"### Understanding the core interfaces"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2d8b8e94",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### 1. `RuntimeArguments` / `GemmArguments`\n",
|
||||
"\n",
|
||||
"`RuntimeArguments` describe the operation a user wants to perform, and all the runtime operands or other runtime parameters needed for it. \n",
|
||||
"This includes primary runtime operands to the operation, as well as any custom epilogue fusions and runtime performance knobs.\n",
|
||||
"\n",
|
||||
"We provide builtin subtypes of `RuntimeArguments` for common operations (e.g. GEMM, Elementwise ops; more later).\n",
|
||||
"\n",
|
||||
"For instance, `GemmArguments` is a type of `RuntimeArguments`:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"@dataclass\n",
|
||||
"class GemmArguments(RuntimeArguments):\n",
|
||||
" A: TensorLike\n",
|
||||
" B: TensorLike\n",
|
||||
" out: TensorLike\n",
|
||||
" accumulator_type: NumericLike\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"`GemmArguments` conveys:\n",
|
||||
"* We want to perform a dense GEMM operation (`out = A @ B`)\n",
|
||||
"* We want to perform it for operands in `A, B, out`, with intermediate results stored as `accumulator_type`\n",
|
||||
"* We can optionally set a custom epilogue that is fused on top of the base GEMM. Some kernels also support some runtime performance controls which can be specified here. These will be discussed in detail in other tutorials.\n",
|
||||
"\n",
|
||||
"It is a kernel-agnostic way to specify the desired functionality.\n",
|
||||
"\n",
|
||||
"`RuntimeArguments` can be constructed from any `TensorLike` object. This includes `torch.Tensor`, `cute.Tensor`, or any other DLPack-compatible tensors."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e7eda0dd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### 2. Kernel Discovery\n",
|
||||
"\n",
|
||||
"There are several kernels available in CUTLASS DSLs that are registered with, and discoverable via, the CUTLASS API.\n",
|
||||
"\n",
|
||||
"This includes kernels for various operations (GEMM, Elementwise operations, ...), which implement various algorithms & architecture features. Within the same implementation, there are several instances or configurations of it with different combinations of operand types, layouts, tile sizes, etc.\n",
|
||||
"\n",
|
||||
"In the previous step, we used `GemmArguments` to specify our desired GEMM in a kernel-agnostic way. Now we find kernels that can fulfill that functionality. A subset of the available kernels will perform GEMM, and a subset of _those_ will support the properties of specific operands we are currently using."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "3b737131",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"A total of 107616 kernel instances are available.\n",
|
||||
"Of these, 350 support the given arguments.\n",
|
||||
"Picked kernel with name: cutedsl.PersistentDenseGemmKernel_sm100_ttt_AFloat16_BFloat16_outFloat32_accFloat32_2cta_cluster2x1x1_tile128x32x256_tma_store\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# get_kernels() fetches all kernels when called without args\n",
|
||||
"all_kernels = cutlass_api.get_kernels()\n",
|
||||
"print(f\"A total of {len(all_kernels)} kernel instances are available.\")\n",
|
||||
"\n",
|
||||
"# we can limit the search to kernels supporting given args\n",
|
||||
"kernels = cutlass_api.get_kernels(args)\n",
|
||||
"print(f\"Of these, {len(kernels)} support the given arguments.\")\n",
|
||||
"\n",
|
||||
"kernel = kernels[0]\n",
|
||||
"print(f\"Picked kernel with name: {kernel.metadata.kernel_name}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "252a4d38",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### 3. `Kernel` execution"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "574d004b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Once we have selected a kernel, we are now ready to execute it. We previously showed the simplest way to do this is `kernel.run(args)`.\n",
|
||||
"\n",
|
||||
"This method does the following:\n",
|
||||
"* verify that the kernel supports the given `args`\n",
|
||||
"* JIT-compile the kernel\n",
|
||||
"* launch the compiled kernel function\n",
|
||||
"\n",
|
||||
"Users can do these steps individually for more control:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e8945aa6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"* `kernel.supports(args)` checks if the kernel supports the given `args`\n",
|
||||
" * this is relevant if the kernel was not picked just for these `args`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "159fd610",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"supported = kernel.supports(args)\n",
|
||||
"assert supported"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "948689a8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If the arguments are not supported by this kernel, `supports` returns a `Status` object explaining the error."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "2cfc9ea7",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Operand `A` is unsupported: Expected element type Float16, got BFloat16\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"unsupported_args = cutlass_api.arguments.GemmArguments(\n",
|
||||
" A=A.to(torch.bfloat16), B=B, out=out, accumulator_type=acc_type\n",
|
||||
")\n",
|
||||
"if not (status := kernel.supports(unsupported_args)):\n",
|
||||
" print(status.error)\n",
|
||||
"\n",
|
||||
"assert not status"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c2db8f20",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"* `kernel.compile(args)` compiles the kernel, and returns a `CompiledArtifact`\n",
|
||||
"\n",
|
||||
"This compiled artifact is a lightweight wrapper over the result of compiling a kernel (e.g., via `cute.compile()`).\n",
|
||||
"\n",
|
||||
"For just-in-time compilation, we can use the compiled artifact straightaway.\n",
|
||||
"In the future, we will support optionally serializing it for ahead-of-time compilation and deserialized in a different context.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "02e79eb8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"compiled_artifact = kernel.compile(args)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4dfb8d51",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"* `kernel.run(args)` launches the compiled kernel function. This example uses:\n",
|
||||
" * the precompiled artifact\n",
|
||||
" * a custom stream to launch to\n",
|
||||
" * bypasses the supports check already performed above (`assume_supported_args=True`)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "02398bf0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# zero the output to avoid testing stale output\n",
|
||||
"out.zero_()\n",
|
||||
"\n",
|
||||
"kernel.run(\n",
|
||||
" args,\n",
|
||||
" compiled_artifact,\n",
|
||||
" stream=torch.cuda.Stream(),\n",
|
||||
" assume_supported_args=True,\n",
|
||||
")\n",
|
||||
"torch.testing.assert_close(out, reference)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f67eeb8f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Some kernels may also require a device \"workspace\". This is an additional buffer needed by some kernels for book-keeping, temporary results, etc.\n",
|
||||
"Its size can be queried using `kernel.get_workspace_size(args)`. Most kernels will have a workspace size of 0.\n",
|
||||
"If a kernel does have a non-zero workspace size, an additional buffer of at least that size must be provided. Without it, the kernel behavior is undefined."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "5b2e2d56",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"workspace_size = kernel.get_workspace_size(args)\n",
|
||||
"workspace = torch.empty(workspace_size, device=\"cuda\", dtype=torch.int8)\n",
|
||||
"\n",
|
||||
"out.zero_()\n",
|
||||
"kernel.run(args, compiled_artifact, stream=torch.cuda.Stream(), workspace=workspace)\n",
|
||||
"torch.testing.assert_close(out, reference)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "baffaf12",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Advanced: Filtering on Metadata"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "86dd521a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Using `RuntimeArguments` to search for supporting kernels is a convenient way to discover kernels: users directly specify their desired functionality, and `get_kernels()` finds the supporting kernels.\n",
|
||||
"It covers all logical operands of any operation, as well as (in later examples) epilogue fusions, and performance controls.\n",
|
||||
"\n",
|
||||
"However, there may be cases where users want more advanced ways to query kernels. These could be:\n",
|
||||
"* when the desired properties may not be expressed in runtime controls\n",
|
||||
" * the simplest scenario may be if you're searching searching for a kernel with a specific name, a specific class, etc.\n",
|
||||
" * searching for kernel's static properties such as tile size, cluster size, etc.\n",
|
||||
"* when the `RuntimeArguments` are not available or you want to generate & pre-compile a broader set of kernels\n",
|
||||
"\n",
|
||||
"For such cases, we provide a more advanced filtering based on `KernelMetadata`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ef54ae50",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`KernelMetadata` captures a wide variety of properties of a `Kernel`.\n",
|
||||
"\n",
|
||||
"These are properties of a kernel's functional support (like operand types, layouts, alignments), as well as architectural/design choices & performance characteristics (like tilze size, scheduling characteristics).\n",
|
||||
"\n",
|
||||
"Different kernels may use different sub-classes of `metadata.operands`, `metadata.design`, `metadata.epilogue` for flexibility, which can also identify their characteristics.\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"@dataclass\n",
|
||||
"class KernelMetadata:\n",
|
||||
" kernel_name: str\n",
|
||||
" kernel_class: type[\"Kernel\"]\n",
|
||||
" min_cc: int\n",
|
||||
" operands: OperandsMetadata\n",
|
||||
" design: DesignMetadata | None = None\n",
|
||||
" epilogue: EpilogueMetadata | None = None\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f9943888",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Every unique kernel instance can be distinguished by its metadata.\n",
|
||||
"It can be used in filtering for kernels in addition to the `RuntimeArguments`, by providing a custom `metadata_filter`.\n",
|
||||
"\n",
|
||||
"Here, we get all kernels that support `args`, and have `metadata.design` of type `Sm100DesignMetadata`.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "8717ac89",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Found 350 kernels which support args & have Sm100DesignMetadata\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"kernels = cutlass_api.get_kernels(\n",
|
||||
" args,\n",
|
||||
" metadata_filter=lambda metadata: isinstance(\n",
|
||||
" metadata.design, cutlass_api.metadata.Sm100DesignMetadata\n",
|
||||
" ),\n",
|
||||
")\n",
|
||||
"print(f\"Found {len(kernels)} kernels which support args & have Sm100DesignMetadata\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1d1f9124",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can construct more advanced filters by leveraging duck-typing.\n",
|
||||
"Additionally, we can get all the kernels that match our filter, rather than supporting a fully-defined set of arguments.\n",
|
||||
"This could be useful to pre-generate large set of kernels not targeted to any one problem."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "a76ec20f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Found 9400 matching kernels\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def a_more_complex_filter(metadata: cutlass_api.metadata.KernelMetadata) -> bool:\n",
|
||||
" \"\"\"\n",
|
||||
" Find all GEMM kernels that support Float16 A and 2-CTA MMA\n",
|
||||
" \"\"\"\n",
|
||||
" # Only look at GEMM kernels\n",
|
||||
" if not isinstance(metadata.operands, cutlass_api.metadata.GemmOperandsMetadata):\n",
|
||||
" return False\n",
|
||||
" # Only look at kernels with A-type F16\n",
|
||||
" if metadata.operands.A.dtype != cutlass.Float16:\n",
|
||||
" return False\n",
|
||||
" # Only look at kernels with tile_shape[0] == 128\n",
|
||||
" if getattr(metadata.design, \"tile_shape\", [None])[0] != 128:\n",
|
||||
" return False\n",
|
||||
" return True\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Look ma, no args! Fetch all kernels that match the filter,\n",
|
||||
"# instead of supporting a complete set of args\n",
|
||||
"kernels = cutlass_api.get_kernels(\n",
|
||||
" args=None,\n",
|
||||
" metadata_filter=a_more_complex_filter,\n",
|
||||
")\n",
|
||||
"print(f\"Found {len(kernels)} matching kernels\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
518
python/cutlass_api/examples/001_gemm_with_fused_epilogue.ipynb
Normal file
518
python/cutlass_api/examples/001_gemm_with_fused_epilogue.ipynb
Normal file
@@ -0,0 +1,518 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a31330d3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Custom epilogue fusions for GEMMs\n",
|
||||
"\n",
|
||||
"Note: this notebook requires a GPU with compute capability 100 or 103:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bb450878",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cutlass_api\n",
|
||||
"\n",
|
||||
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
|
||||
" print(f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\")\n",
|
||||
" import sys\n",
|
||||
"\n",
|
||||
" sys.exit(0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "154e9d59",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The CUTLASS API provides flexible epilogue fusion support by allowing for the specification of an epilogue via high-level tensor operations that one would like to compose with an operation.\n",
|
||||
"\n",
|
||||
"For those familiar with the legacy CUTLASS Python API's [epilogue visitor tree frontend](https://github.com/NVIDIA/cutlass/blob/a2439551c765c5393aebe557ee75d3a0412d2211/examples/python/deprecated/04_epilogue_visitor.ipynb), much of the interface is shared.\n",
|
||||
"\n",
|
||||
"The CUTLASS API enables one to express an epilogue using a function operating at the `torch.Tensor`-level, and has tooling to automatically add this to kernels supporting the provided function. \n",
|
||||
"\n",
|
||||
"For example, in PyTorch one might write the following to compute a GEMM + epilogue:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e6d77d53",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"torch.manual_seed(2025)\n",
|
||||
"\n",
|
||||
"L, M, N, K = 1, 1024, 1024, 1024\n",
|
||||
"A = torch.randn(L, M, K, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"B = torch.randn(L, K, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"C = torch.randn(L, M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"\n",
|
||||
"def my_epilogue(accum, C, alpha, beta, extra_scalar):\n",
|
||||
" Aux = (alpha * accum) + (beta * C)\n",
|
||||
" D = extra_scalar * Aux\n",
|
||||
" return D, Aux\n",
|
||||
"\n",
|
||||
"alpha, beta, extra_scalar = 1.0, 2.0, 0.5\n",
|
||||
"D, Aux = my_epilogue(A @ B, C, alpha, beta, extra_scalar)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "66ee4dd1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The CUTLASS API allows the same epilogue function `my_epilogue` to be used in GEMMs provided by the API.\n",
|
||||
"\n",
|
||||
"To do so, one defines `EpilogueArguments` consisting of the epilogue function to compute (or a string representation of it) along with arguments corresponding to each input and output of the function (except for `accum`):"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f079d9d6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cutlass_api\n",
|
||||
"from cutlass_api.arguments import GemmArguments, EpilogueArguments\n",
|
||||
"\n",
|
||||
"# Allocate buffers for D and Aux\n",
|
||||
"D_, Aux_ = [torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16) for _ in range(2)]\n",
|
||||
"\n",
|
||||
"epi_args = EpilogueArguments(my_epilogue, C=C, alpha=alpha, beta=beta, extra_scalar=extra_scalar, D=D_, Aux=Aux_)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "97ef8e8a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"These arguments can be added to `GemmArguments` and passed in to `get_kernels()` for use when retrieving compatible kernels:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "60215c4e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
"kernels = cutlass_api.get_kernels(args)\n",
|
||||
"assert len(kernels) > 0\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b0a7f9a2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Each of the kernels returned by `get_kernels` can be compiled and executed just the same with these new arguments, as it was in examples without\n",
|
||||
"epilogue fusion. For example, using the first kernel:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "150f3296",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"kernels[0].run(args)\n",
|
||||
"\n",
|
||||
"torch.testing.assert_close(D, D_)\n",
|
||||
"torch.testing.assert_close(Aux, Aux_)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f1a826e3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## How the epilogue fusion API works\n",
|
||||
"To support specifying an epilogue via a Python function, a kernel needs some mechanism to:\n",
|
||||
"1. Detect the operations in the epilogue function\n",
|
||||
"2. Determine if the kernel can support the operations\n",
|
||||
"3. Emit code to perform these operations within the kernel\n",
|
||||
"\n",
|
||||
"Step 1 listed above does not depend on the kernel and its implementation (e.g., DSL), while steps 2 and 3 depend on the kernel and/or its implementation.\n",
|
||||
"\n",
|
||||
"Thus, the CUTLASS API separates these components so that step 1 takes place at the API level and steps 2 and 3 take place in the kernel. This process is visualized below. We will walk through each step in greater detail.\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
" +------------------------------------+\n",
|
||||
" | def epi(accum, alpha, beta, C): |\n",
|
||||
" | D = (accum * alpha) + (beta * C) | 1. Define epilogue via a Python function\n",
|
||||
" | return D |\n",
|
||||
" +------------------------------------+\n",
|
||||
" |\n",
|
||||
" |\n",
|
||||
" |\n",
|
||||
" GemmArguments(..., 2. Pass epilogue function, operands, and outputs\n",
|
||||
" epilogue=EpilogueArguments( to EpilogueArguments constructor,\n",
|
||||
" epi, alpha=alpha, beta=beta, C=C)) and add this to the GemmArguments. Under the\n",
|
||||
" | hood, this parses the Python AST of the\n",
|
||||
" | epilogue function to produce a DAG of load,\n",
|
||||
" | store, and compute nodes.\n",
|
||||
" V\n",
|
||||
" +-----------------------------------------+ \n",
|
||||
" | Intermediate DAG representation |\n",
|
||||
" | =============================== |\n",
|
||||
" | |\n",
|
||||
" | Store() |\n",
|
||||
" | | |\n",
|
||||
" | Add() |\n",
|
||||
" | / \\ |\n",
|
||||
" | / \\ |\n",
|
||||
" | / \\ |\n",
|
||||
" | Mul() Mul() |\n",
|
||||
" | / \\ / \\ |\n",
|
||||
" | AccFetch() | Load(C) \\ |\n",
|
||||
" | | \\ |\n",
|
||||
" | Load(alpha) Load(beta) |\n",
|
||||
" | |\n",
|
||||
" +-----------------------------------------+\n",
|
||||
" / | \\\n",
|
||||
" / | \\ 3. Individual kernel classes use the DAG representation\n",
|
||||
" / | \\ to determine if the kernel class supports the DAG.\n",
|
||||
" Kernel 0 Kernel 1 Kernel 2 If so, the kernel class emits DSL-level operations\n",
|
||||
" epilogue epilogue epilogue needed to compute the epilogue DAG alongside the\n",
|
||||
" emitter emitter emitter basic operation of the kernel (e.g., GEMM).\n",
|
||||
" | | |\n",
|
||||
" | | |\n",
|
||||
" V V V\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"### Defining an epilogue via a Python function\n",
|
||||
"Epilogue fusion patterns are defined by users in Python functions that perform Tensor-level operations -- using a `torch.Tensor` (for example) resulting from matrix multiplication, the function must be able to compute the desired results of the epilogue.\n",
|
||||
"\n",
|
||||
"The structure of these functions is as follows:\n",
|
||||
"```python\n",
|
||||
"def custom_epi_name(accum, *args) -> Union[TensorType, tuple[TensorType]]:\n",
|
||||
" \"\"\"\n",
|
||||
" :param accum: result of matrix multiplication, convolution, etc. before the epilogue\n",
|
||||
" :type accum: TensorType\n",
|
||||
" :param args: additional arguments to be used in the epilogue (e.g., aux tensors)\n",
|
||||
" :type args: list[Union[TensorType, ScalarType]]\n",
|
||||
"\n",
|
||||
" :returns: at least one tensor resulting from the operation of the epilogue\n",
|
||||
" :rtype: Union[TensorType, tuple[TensorType]]\n",
|
||||
" \"\"\"\n",
|
||||
" # Do some compute\n",
|
||||
" return D # and potentially other values\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The user defines a custom epilogue via a Python function that **must** do at least the following:\n",
|
||||
"1. Take in a first positional argument named `accum` that represents the result of operation just before the epilogue is to be performed. For example, in a GEMM, `accum = A @ B`.\n",
|
||||
"2. Return at least one tensor that results from computing the epilogue. Currently, the return list must contain at least one output named `D`, though this constraint may be loosened in the future.\n",
|
||||
"\n",
|
||||
"Each additional argument following `accum` in the function definition is expected to be either a Tensor or scalar to be loaded. Each variable in the return statement represents a Tensor or scalar to be stored. The underlying implementation of the epilogue in the kernel will determine how operands are loaded and stored.\n",
|
||||
"\n",
|
||||
"Compute operations are represented in static single assignment (SSA) form.\n",
|
||||
"This means that each variable can be assigned exactly once.\n",
|
||||
"Operations currently supported ares:\n",
|
||||
"* Tensor-tensor elementwise addition, subtraction, multiplication, and division\n",
|
||||
"* Scalar broadcasts via addition, subtraction, multiplication, and division\n",
|
||||
"* Predefined elementwise activation functions (e.g., ReLU, sigmoid, tanh)\n",
|
||||
"\n",
|
||||
"Operations that are not yet supported include:\n",
|
||||
"* Row/column broadcasts (planned to be added soon)\n",
|
||||
"* Reductions (planned to be added soon)\n",
|
||||
"* Binary minimum and maximum functions (planned to be added soon)\n",
|
||||
"If attempting to use these operations will result in no kernels being found in the call to `get_kernels`.\n",
|
||||
"\n",
|
||||
"Violations to SSA or use of unexpected operators will be flagged with an exception when parsing the AST of the custom epilogue.\n",
|
||||
"\n",
|
||||
"Examples of epilogues fitting these patterns are given below. We will show full, runnable examples at the end of this notebook.\n",
|
||||
"```python\n",
|
||||
"def relu_aux_store(accum, alpha, C):\n",
|
||||
" # Note that the function definition itself does not indicate the types and\n",
|
||||
" # ranks of alpha and C. Thus, one cannot tell whether the epilogue is performing\n",
|
||||
" # broadcasts or elementwise operations until actual arguments or metadata are\n",
|
||||
" # provided to the epilogue. See below for details.\n",
|
||||
" F = (accum * alpha) + (C * 2.0) # Constant beta of 2.0\n",
|
||||
" D = relu(F)\n",
|
||||
" return D, F\n",
|
||||
"\n",
|
||||
"def aux_normalize(accum, aux):\n",
|
||||
" D = accum / aux\n",
|
||||
" return D\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Additional information about each operand and output must be provided by the user when constructing `EpilogueArguments`, as we will discuss below. This additional information is necessary for fully defining the operations being performed -- without knowledge of whether `alpha` is a scalar or a Tensor, we cannot determine whether multiplication by `alpha` is a broadcasted or elementwise operation.\n",
|
||||
"\n",
|
||||
"### Constructing epilogue arguments\n",
|
||||
"`EpilogueArguments` encapsulate the arguments needed to determine the functional operation of a fused epilogue.\n",
|
||||
"\n",
|
||||
"A user must provide in the construction of `EpilogueArguments` tensors for all operands and outputs of the epilogue. However, unlike arguments for basic operations (e.g., GEMM), the full set of operands needed to be specified for an epilogue pattern depends upon the custom epilogue defined by the user.\n",
|
||||
"\n",
|
||||
"Therefore, `EpilogueArguments` is defined generically as taking in an `epilogue_fn` and additional `kwargs`. Under the hood, the AST for `epilogue_fn` is parsed to determine the operands and outputs of the epilogue. The user is required to provide in `kwargs` Tensors or scalars for all operands and outputs in the provided epilogue.\n",
|
||||
"\n",
|
||||
"For example, with an epilogue of:\n",
|
||||
"```python\n",
|
||||
"def my_epi(accum, alpha, C, beta):\n",
|
||||
" F = (accum * alpha) + (C * beta)\n",
|
||||
" D = relu(F)\n",
|
||||
" return D, F\n",
|
||||
"```\n",
|
||||
"A user would need to construct epilogue arguments as follows:\n",
|
||||
"```python\n",
|
||||
"epi_args = EpilogueArguments(my_epi, alpha=..., C=..., beta=..., D=..., F=...)\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"After verifying that all required operands and outputs are present, the constructor to `EpilogueArguments` will perform additional passes on the AST of `epilogue_fn` using the provided inputs to generate an internal DAG representing the epilogue. This DAG structure is attached to `EpilogueArguments` for use as they are passed through a call to `get_kernels`.\n",
|
||||
"\n",
|
||||
"### Discovering kernels that support the epilogue pattern\n",
|
||||
"\n",
|
||||
"The call to `get_kernels(args)` will return any kernels that support the provided `GemmArguments`.\n",
|
||||
"Since the `GemmArguments` constructed above now include `EpilogueArguments`, returned kernels must support the provided epilogue.\n",
|
||||
"\n",
|
||||
"Under the hood of `get_kernels()`, each `Kernel` class will determine in its `generate_kernels()` method whether it supports the provided `EpilogueArguments`.\n",
|
||||
"It can do so by traversing the DAG that resulted from the construction of `EpilogueArguments` to find the operations that compose the epilogue.\n",
|
||||
"Assuming that the `Kernel` can support the DAG, it must then add to the source for the kernel any operations needed to support the DAG.\n",
|
||||
"An example of how this is done generically for an SM100 CuTe DSL GEMM is provided in `sm100_static_persistent_efc.py`.\n",
|
||||
"\n",
|
||||
"## Example epilogues\n",
|
||||
"We now provide various examples of adding custom epilogues to GEMM kernels targeting SM100. A broader set of epilogue examples are available in `test_gemm_epilogue_fusion.py`.\n",
|
||||
"\n",
|
||||
"### Auxiliary input and output tensors"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "171ac178",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from cutlass_api.fusion.activation import relu\n",
|
||||
"\n",
|
||||
"def relu_aux_store(accum, alpha, C):\n",
|
||||
" F = (accum * alpha) + (C * 2.0) # Constant beta\n",
|
||||
" D = relu(F)\n",
|
||||
" return D, F\n",
|
||||
"\n",
|
||||
"C = torch.randn((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
|
||||
"alpha = 3.0\n",
|
||||
"D = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
|
||||
"F = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
|
||||
"\n",
|
||||
"epi_args = EpilogueArguments(relu_aux_store, alpha=alpha, C=C, D=D, F=F)\n",
|
||||
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
|
||||
"assert len(kernels) > 0\n",
|
||||
"kernels[0].run(args)\n",
|
||||
"\n",
|
||||
"D_ref, F_ref = relu_aux_store(A @ B, alpha, C)\n",
|
||||
"\n",
|
||||
"torch.testing.assert_close(D, D_ref)\n",
|
||||
"torch.testing.assert_close(F, F_ref)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f947b403",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Keyword functions and returning accumulator"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "62c2b49b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def relu_scale_return_acc(accum, alpha, beta, C, scale):\n",
|
||||
" F = relu((accum * alpha) + (C * beta))\n",
|
||||
" D = F * scale\n",
|
||||
" return D, F, accum\n",
|
||||
"\n",
|
||||
"C = torch.randn((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
|
||||
"alpha = 1.0\n",
|
||||
"beta = 2.0\n",
|
||||
"scale = 0.5\n",
|
||||
"D = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
|
||||
"F = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
|
||||
"accum = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float32)\n",
|
||||
"\n",
|
||||
"epi_args = EpilogueArguments(relu_scale_return_acc, alpha=alpha, beta=beta, C=C, scale=scale, D=D, F=F, accum=accum)\n",
|
||||
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
|
||||
"assert len(kernels) > 0\n",
|
||||
"kernels[0].run(args)\n",
|
||||
"\n",
|
||||
"D_ref, F_ref, accum_ref = relu_scale_return_acc(A @ B, alpha, beta, C, scale)\n",
|
||||
"\n",
|
||||
"torch.testing.assert_close(D, D_ref)\n",
|
||||
"torch.testing.assert_close(F, F_ref)\n",
|
||||
"torch.testing.assert_close(accum, accum_ref.to(accum.dtype))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c641911f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Passing a string representation of the function\n",
|
||||
"`EpilogueArguments` can additionally be constructed using a string representation of the epilogue function:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5987bf44",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"epi_str = \"def epi(accum, alpha, beta, C): F = (accum * alpha) + (C * beta); D = relu(F); return D, F\"\n",
|
||||
"\n",
|
||||
"C = torch.randn((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
|
||||
"alpha = 1.0\n",
|
||||
"beta = 0.5\n",
|
||||
"D = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
|
||||
"F = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
|
||||
"\n",
|
||||
"epi_args = EpilogueArguments(epi_str, alpha=alpha, beta=beta, C=C, D=D, F=F)\n",
|
||||
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
|
||||
"assert len(kernels) > 0\n",
|
||||
"kernels[0].run(args)\n",
|
||||
"\n",
|
||||
"F_ref = (A @ B) * alpha + (C * beta)\n",
|
||||
"D_ref = torch.relu(F_ref)\n",
|
||||
"\n",
|
||||
"torch.testing.assert_close(D, D_ref)\n",
|
||||
"torch.testing.assert_close(F, F_ref)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e26a58a2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Failure examples\n",
|
||||
"The following are examples of constructing `EpilogueArguments` that are expected to fail."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1e3d0c89",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"####################################################\n",
|
||||
"# Epilogues must take in an accumulator\n",
|
||||
"####################################################\n",
|
||||
"def fail_missing_accum(alpha, beta, C):\n",
|
||||
" D = (C * beta)\n",
|
||||
" return D\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" epi_args = EpilogueArguments(fail_missing_accum, alpha=alpha, beta=beta, C=C, D=D)\n",
|
||||
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
"except Exception as e:\n",
|
||||
" # \"accum must be an input to the epilogue function\"\n",
|
||||
" print(e)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "48a359f7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"####################################################\n",
|
||||
"# Epilogues must return an output named D\n",
|
||||
"####################################################\n",
|
||||
"def fail_missing_D(accum, alpha, beta, C):\n",
|
||||
" F = (accum * alpha) + (C * beta)\n",
|
||||
" return F\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" epi_args = EpilogueArguments(fail_missing_D, alpha=alpha, beta=beta, C=C, F=F)\n",
|
||||
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
"except Exception as e:\n",
|
||||
" # \"On SM90 or higher, D is expected to be a output node with 0 users to enable smem reuse between C and D, but got []\"\n",
|
||||
" print(e)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "49d9ee94",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"####################################################\n",
|
||||
"# Epilogues must use single-static assignment (SSA)\n",
|
||||
"####################################################\n",
|
||||
"def fail_ssa(accum):\n",
|
||||
" tmp = accum * 2.0\n",
|
||||
" # Redefine tmp, which violates SSA form.\n",
|
||||
" tmp = tmp - 1.0\n",
|
||||
" D = tmp / 4.0\n",
|
||||
" return D, tmp\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" epi_args = EpilogueArguments(fail_ssa, D=D, tmp=F)\n",
|
||||
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
"except Exception as e:\n",
|
||||
" # \"Variable 'tmp' cannot be defined twice.\"\n",
|
||||
" print(e)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "871bb727",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"####################################################\n",
|
||||
"# Must provide all operands and outputs to\n",
|
||||
"# EpilogueArguments\n",
|
||||
"####################################################\n",
|
||||
"def my_epi(accum, alpha, beta, C):\n",
|
||||
" F = (accum * alpha) + (C * beta)\n",
|
||||
" D = relu(F)\n",
|
||||
" return D\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" # Missing D\n",
|
||||
" epi_args = EpilogueArguments(my_epi, alpha=alpha, beta=beta, C=C)\n",
|
||||
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
"except Exception as e:\n",
|
||||
" # \"Argument D is not provided in the kwargs of the EpilogueArguments constructor\"\n",
|
||||
" print(e)\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" # Missing alpha\n",
|
||||
" epi_args = EpilogueArguments(my_epi, beta=beta, C=C, D=D)\n",
|
||||
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
"except Exception as e:\n",
|
||||
" # \"Argument alpha is not provided in the kwargs of the EpilogueArguments constructor\"\n",
|
||||
" print(e)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
548
python/cutlass_api/examples/002_bring_your_own_kernel.ipynb
Normal file
548
python/cutlass_api/examples/002_bring_your_own_kernel.ipynb
Normal file
@@ -0,0 +1,548 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "578f2730",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Adding a kernel to the CUTLASS API\n",
|
||||
"The CUTLASS API is designed to make it easy for users to add their own kernel\n",
|
||||
"so that it can be discovered and run under the uniform API. We welcome contributions\n",
|
||||
"toward the API by \"bringing your own kernel.\"\n",
|
||||
"\n",
|
||||
"This example shows how to add a CuTe DSL kernel to the CUTLASS API.\n",
|
||||
"\n",
|
||||
"## Bring your own implementation\n",
|
||||
"Individuals wishing to add a CuTe DSL kernel to the CUTLASS API likely already\n",
|
||||
"have the kernel written in CuTe DSL, but have not yet implemented the API's needed\n",
|
||||
"interface. Within the API, we separate these components into the \"implementation\" --\n",
|
||||
"the kernel written in CuTe DSL -- and the \"interface\" -- the definition of methods\n",
|
||||
"a kernel needs to be used within the CUTLASS API.\n",
|
||||
"\n",
|
||||
"For example, consider the following implementation of a simple FP64 GEMM kernel implementation:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5a64b0be",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Callable\n",
|
||||
"\n",
|
||||
"import cuda.bindings.driver as cuda\n",
|
||||
"\n",
|
||||
"import cutlass\n",
|
||||
"import cutlass.cute as cute\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class F64GemmKernelImplementation:\n",
|
||||
" def __init__(self, cta_tile_shape_mn: tuple[int, int]):\n",
|
||||
" self.cta_tile_shape_mn = cta_tile_shape_mn\n",
|
||||
"\n",
|
||||
" @cute.jit\n",
|
||||
" def __call__(\n",
|
||||
" self, a: cute.Tensor, b: cute.Tensor, out: cute.Tensor, stream: cuda.CUstream\n",
|
||||
" ):\n",
|
||||
" l, m, n = out.shape\n",
|
||||
" m_tiles = (m + self.cta_tile_shape_mn[0] - 1) // self.cta_tile_shape_mn[0]\n",
|
||||
" n_tiles = (n + self.cta_tile_shape_mn[1] - 1) // self.cta_tile_shape_mn[1]\n",
|
||||
"\n",
|
||||
" grid = (m_tiles, n_tiles, l)\n",
|
||||
" block = [self.cta_tile_shape_mn[0], self.cta_tile_shape_mn[1], 1]\n",
|
||||
" self.kernel(a, b, out).launch(grid=grid, block=block, stream=stream)\n",
|
||||
"\n",
|
||||
" @cute.kernel\n",
|
||||
" def kernel(self, a: cute.Tensor, b: cute.Tensor, out: cute.Tensor):\n",
|
||||
" l, m, n = out.shape\n",
|
||||
" k = a.shape[-1]\n",
|
||||
" m_tile, n_tile, l_idx = cute.arch.block_idx()\n",
|
||||
" tidx, tidy, _ = cute.arch.thread_idx()\n",
|
||||
"\n",
|
||||
" m_idx = m_tile * self.cta_tile_shape_mn[0] + tidx\n",
|
||||
" n_idx = n_tile * self.cta_tile_shape_mn[1] + tidy\n",
|
||||
"\n",
|
||||
" if m_idx < m and n_idx < n:\n",
|
||||
" out[l_idx, m_idx, n_idx] = cutlass.Float64(0)\n",
|
||||
" for k_idx in range(k):\n",
|
||||
" out[l_idx, m_idx, n_idx] += (\n",
|
||||
" a[l_idx, m_idx, k_idx] * b[l_idx, k_idx, n_idx]\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "36a08d4b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The implementation is configurable via a `cta_tile_shape_mn` argument, which\n",
|
||||
"controls the size of blocks and tiles in the M and N modes. A simple `cute.jit` function\n",
|
||||
"computes the grid and block size for the input problem based on `cta_tile_shape_mn`,\n",
|
||||
"and launches the kernel. The `cute.kernel` itself simply has each thread compute a single\n",
|
||||
"output element of the matrix by taking a dot product.\n",
|
||||
"\n",
|
||||
"This implementation is not performant, but is kept simple for illustrative purposes."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a5d0e661",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Defining interface methods\n",
|
||||
"As it currently stands, this GEMM kernel implementation cannot be used via the\n",
|
||||
"CUTLASS API because it does not implement interface methods. Specifically, kernels\n",
|
||||
"within the CUTLASS API must inherit from and implement the `cutlass_api.Kernel`\n",
|
||||
"abstract class. This class has methods needed for many common operations\n",
|
||||
"performed when compiling and executing DSL kernels.\n",
|
||||
"\n",
|
||||
"Certain providers (i.e., DSLs), such as CuTe DSL, provide an additional layer atop the\n",
|
||||
"`cutlass_api.Kernel` class to add utilities for kernels being written\n",
|
||||
"via that provider. For example, the CuTe DSL provider in the CUTLASS API\n",
|
||||
"defines `cutlass_api.providers.cutedsl.kernel.CuteDslKernel`, which adds utilities surrounding\n",
|
||||
"`cute.compile()` to add compile-time arguments needed for using TVM-FFI when\n",
|
||||
"it is enabled.\n",
|
||||
"\n",
|
||||
"We will next walk through the steps in defining interface methods for this\n",
|
||||
"implementation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1a2da869",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import itertools\n",
|
||||
"\n",
|
||||
"import cutlass_api\n",
|
||||
"from cutlass_api.arguments import GemmArguments\n",
|
||||
"from cutlass_api.metadata import KernelMetadata\n",
|
||||
"from cutlass_api.status import Status"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "86ae75cc",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We begin by defining a class to represent the kernel's interface.\n",
|
||||
"As mentioned above, since this is a CuTe DSL kernel, our interface must\n",
|
||||
"inherit from and implement `cutlass_api.providers.cutedsl.kernel.CuteDslKernel`.\n",
|
||||
"\n",
|
||||
"The class must additionally be registered with the CuTe DSL provider\n",
|
||||
"via the `@CuTeDSLProvider.register` decorator so that the class\n",
|
||||
"can be considered when discovering kernels."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3a86d138",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@cutlass_api.providers.cutedsl.CuTeDSLProvider.register\n",
|
||||
"class F64GemmKernel(cutlass_api.providers.cutedsl.kernel.CuteDslKernel):\n",
|
||||
" # Empty versions of interface methods. These will be implemented later, interspersed\n",
|
||||
" # with notebook markdown. Normally, one would define them inline with the class definition.\n",
|
||||
" def __init__(self, metadata: KernelMetadata): pass\n",
|
||||
"\n",
|
||||
" def _run(self, args: GemmArguments, artifact: cutlass_api.artifact.CompiledArtifact, stream, workspace=None): pass\n",
|
||||
"\n",
|
||||
" def compile(self, args: GemmArguments, cc: int = None) -> cutlass_api.artifact.CompiledArtifact: pass\n",
|
||||
"\n",
|
||||
" @staticmethod\n",
|
||||
" def generate_kernels(metadata_filter, epilogue_args=None, cc=None) -> list[\"F64GemmKernel\"]: pass\n",
|
||||
"\n",
|
||||
" def _supports(self, args: GemmArguments) -> Status: pass\n",
|
||||
"\n",
|
||||
" def get_workspace_size(self, args: GemmArguments) -> int: pass"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "327e9e7c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The `__init__` method of the class takes in a `KernelMetadata` object\n",
|
||||
"from which it extracts the `cta_tile_shape_mn`. This is used to construct\n",
|
||||
"the kernel implementation object. We will discuss later how the `KernelMetadata`\n",
|
||||
"object passed in here is constructed:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "785d1882",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def __init__(self, metadata: KernelMetadata):\n",
|
||||
" self.metadata = metadata\n",
|
||||
" cta_tile_shape_mn = metadata.design.tile_shape[:2]\n",
|
||||
" self.impl = F64GemmKernelImplementation(cta_tile_shape_mn)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "500a0030",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Defining interfaces for compilation and execution\n",
|
||||
"The interfaces needed for compilation and execution are simple.\n",
|
||||
"\n",
|
||||
"The `compile` method simply constructs a placeholder stream object\n",
|
||||
"and passes that and relevant arguments to `self.cute_compile`. This\n",
|
||||
"is a utility defined in the `CuteDSLKernel` abstract class that\n",
|
||||
"passes in compilation flags needed for certain options to `cute.compile`\n",
|
||||
"(e.g., TVM-FFI). The result is wrapped as a `CompiledArtifact`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "63b4a129",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def compile(self, args: GemmArguments, cc: int = None) -> cutlass_api.artifact.CompiledArtifact:\n",
|
||||
" stream = cutlass.cute.runtime.make_fake_stream()\n",
|
||||
" compiled_gemm = self.cute_compile(self.impl, args.A, args.B, args.out, stream)\n",
|
||||
" return cutlass_api.artifact.CompiledArtifact(compiled_gemm, self)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "023127fd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Users define the `_run` method rather than the top-level `run` method\n",
|
||||
"(no leading underscore) that is used in interacting with kernels. `_run` (1) extracts from `args`\n",
|
||||
"the arguments needed to run the JIT function, and (2) calls the JIT function\n",
|
||||
"passed in via `artifact` with these arguments."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2ae7c009",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _run(self, args: GemmArguments, artifact: cutlass_api.artifact.CompiledArtifact, stream, workspace=None):\n",
|
||||
" stream = cutlass_api.utils.to_cuda_stream(stream)\n",
|
||||
" compiled_gemm = artifact.compiled_obj\n",
|
||||
" self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4052e5a0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Finally, since this kernel does not require any device workspace,\n",
|
||||
"we give it a simple `get_workspace_size` method that always returns 0."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "968906ea",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_workspace_size(self, args: GemmArguments) -> int:\n",
|
||||
" return 0"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e245a319",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Defining interfaces for kernel generation\n",
|
||||
"We have implemented the interfaces needed for constructing the kernel\n",
|
||||
"interface, compiling it, and running it. We now must implement methods for\n",
|
||||
"generating the possible configurations of this kernel that the kernel\n",
|
||||
"class itself supports. This will be used in kernel discovery (e.g., via\n",
|
||||
"`cutlass_api.get_kernels()`).\n",
|
||||
"\n",
|
||||
"To do so, we write the `generate_kernels` method. This takes in a\n",
|
||||
"binary function `metadata_filter`, epilogue arguments `epilogue_args`,\n",
|
||||
"and a compute capability `cc`. It returns a list of all instances\n",
|
||||
"of the kernel interface that support the `epilogue_args`, are compatible\n",
|
||||
"with the given `cc`, and which pass the `metadata_filter`.\n",
|
||||
"\n",
|
||||
"The `Kernel` class is responsible for defining what valid possible configurations (instances) of it can exist.\n",
|
||||
"In this example, the valid configurations involve a cross-product of row/column-major strides and two preset tile shapes.\n",
|
||||
"We create a nested loop over these knobs and create a `KernelMetadata` corresponding to each unique configuration.\n",
|
||||
"\n",
|
||||
"The `generate_kernels` method must additionally filter the generated kernels by passing it through a `metadata_filter`.\n",
|
||||
"This is a user-provided custom filter to filter generated metadata combinations. More information on `metadata_filter` is provided in other examples."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "47dc2f20",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@staticmethod\n",
|
||||
"def generate_kernels(\n",
|
||||
" metadata_filter: Callable[[KernelMetadata], bool],\n",
|
||||
" epilogue_args: cutlass_api.arguments.EpilogueArguments = None,\n",
|
||||
" cc: int = None,\n",
|
||||
") -> list[\"F64GemmKernel\"]:\n",
|
||||
"\n",
|
||||
" # The tile shapes this kernel supports/exposes\n",
|
||||
" supported_tile_shapes = [(32, 32, 1), (16, 16, 1)]\n",
|
||||
"\n",
|
||||
" if epilogue_args is not None:\n",
|
||||
" return []\n",
|
||||
"\n",
|
||||
" row_major_stride = (0, 0, 1)\n",
|
||||
" col_major_stride = (0, 1, 0)\n",
|
||||
" stride_combos = list(itertools.product([row_major_stride, col_major_stride], repeat=3))\n",
|
||||
" alignment = 1\n",
|
||||
"\n",
|
||||
" def stride_name(stride): \n",
|
||||
" return \"T\" if stride == row_major_stride else \"N\"\n",
|
||||
"\n",
|
||||
" kernels = []\n",
|
||||
" for tile_shape in supported_tile_shapes:\n",
|
||||
" design_metadata = cutlass_api.metadata.BLASDesignMetadata(tile_shape, (1, 1, 1))\n",
|
||||
" for stride_A, stride_B, stride_out in stride_combos:\n",
|
||||
" # Create TensorAttributes for A, B, and out tensors\n",
|
||||
" a_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_A, alignment)\n",
|
||||
" b_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_B, alignment)\n",
|
||||
" out_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_out, alignment)\n",
|
||||
" layout_str = cutlass_api.utils.strides_to_layout_string(stride_A, stride_B, stride_out)\n",
|
||||
"\n",
|
||||
" name = f\"F64GemmKernel_tile{tile_shape[0]}x{tile_shape[1]}_{layout_str}\"\n",
|
||||
"\n",
|
||||
" metadata = KernelMetadata(\n",
|
||||
" kernel_name=name,\n",
|
||||
" kernel_class=F64GemmKernel,\n",
|
||||
" operands=cutlass_api.metadata.GemmOperandsMetadata(\n",
|
||||
" a_attrs, b_attrs, out_attrs, accumulator_type=cutlass.Float64\n",
|
||||
" ),\n",
|
||||
" design=design_metadata,\n",
|
||||
" min_cc=0,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" if metadata_filter(metadata):\n",
|
||||
" kernels.append(F64GemmKernel(metadata))\n",
|
||||
"\n",
|
||||
" return kernels"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c7cdbc66",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We also add a method for indicating whether a kernel instance in question\n",
|
||||
"supports a set of arguments. The top-level `Kernel.supports` method will\n",
|
||||
"already verify that the `args` passed in match the metadata with which\n",
|
||||
"this `Kernel` instance was constructed. Here, we define additional\n",
|
||||
"checks specific to this kernel, such as that the kernel expects\n",
|
||||
"all operands to be of rank 3:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "54067d47",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _supports(self, args: GemmArguments) -> Status:\n",
|
||||
" if not (\n",
|
||||
" len(args.A.shape) == 3 and # A should be (L, M, K)\n",
|
||||
" len(args.B.shape) == 3 and # B should be (L, K, N)\n",
|
||||
" len(args.out.shape) == 3 # out should be (L, M, N)\n",
|
||||
" ):\n",
|
||||
" return Status.fail(\"All operands must be rank 3.\")\n",
|
||||
" return Status.success()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "edaf2cba",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Assign methods to the class because we interspersed notebook markdown\n",
|
||||
"# with the class definition. This is not needed in a real implementation.\n",
|
||||
"F64GemmKernel.__init__ = __init__\n",
|
||||
"F64GemmKernel.compile = compile\n",
|
||||
"F64GemmKernel._run = _run\n",
|
||||
"F64GemmKernel._supports = _supports\n",
|
||||
"F64GemmKernel.generate_kernels = generate_kernels\n",
|
||||
"F64GemmKernel.get_workspace_size = get_workspace_size"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c8fc84e9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Discovering instances of the kernel and using them\n",
|
||||
"The CUTLASS API is now prepared to discover instances of this\n",
|
||||
"kernel interface just as was done in previous examples.\n",
|
||||
"\n",
|
||||
"We add a small modification of using a `metadata_filter`\n",
|
||||
"to ensure that all returned kernels are instances of the\n",
|
||||
"`F64GemmKernel` class we just implemented. This is needed\n",
|
||||
"only for example/testing purposes."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cec5431d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"torch.manual_seed(2025)\n",
|
||||
"\n",
|
||||
"L, M, N, K = 1, 256, 1024, 128\n",
|
||||
"A = torch.randn(L, M, K, device=\"cuda\", dtype=torch.float64)\n",
|
||||
"B = torch.randn(L, K, N, device=\"cuda\", dtype=torch.float64)\n",
|
||||
"out = torch.empty(L, M, N, device=\"cuda\", dtype=torch.float64)\n",
|
||||
"\n",
|
||||
"args = GemmArguments(A, B, out, accumulator_type=torch.float64)\n",
|
||||
"\n",
|
||||
"def is_f64gemm_kernel(metadata):\n",
|
||||
" return metadata.kernel_class == F64GemmKernel\n",
|
||||
"\n",
|
||||
"kernels = cutlass_api.get_kernels(args, metadata_filter=is_f64gemm_kernel)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "50e81a7d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can print off the names of the first few kernels to see that\n",
|
||||
"they come from our recently-added kernel."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cdb92b5e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(kernels[0].metadata.kernel_name)\n",
|
||||
"print(kernels[1].metadata.kernel_name)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "697ee3c3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can evaluate and test the correctness of an instance of our kernel:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f5486244",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"kernels[0].run(args)\n",
|
||||
"torch.testing.assert_close(out, A @ B)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8de96f7e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can also test the limits of our kernel's design space by providing a\n",
|
||||
"metadata filter that expects a CTA tile size M of 256, which is not exposed\n",
|
||||
"in the `generate_kernels` method of our recently-added kernel. We expect\n",
|
||||
"no kernels of type `F64GemmKernel` to be returned."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "917c74e3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def my_filter(metadata):\n",
|
||||
" return (\n",
|
||||
" is_f64gemm_kernel(metadata) and\n",
|
||||
" isinstance(metadata.design, cutlass_api.metadata.BLASDesignMetadata) and\n",
|
||||
" metadata.design.tile_shape[0] == 256\n",
|
||||
" )\n",
|
||||
"kernels_ctam256 = cutlass_api.get_kernels(args, metadata_filter=my_filter)\n",
|
||||
"\n",
|
||||
"# No kernels should be found\n",
|
||||
"assert len(kernels_ctam256) == 0"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "caa80a7d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## A note on contributing kernels to directory structure\n",
|
||||
"This example showed how to define a kernel inline and add it to the\n",
|
||||
"API for example purposes. This kernel doesn't necessarily need to live\n",
|
||||
"within the API's source code.\n",
|
||||
"\n",
|
||||
"We welcome contributions of kernels that do live within the CUTLASS\n",
|
||||
"API's repository as well.\n",
|
||||
"\n",
|
||||
"Kernels in the repository are organized based on the \"provider\" in which they are\n",
|
||||
"authored (i.e., the DSL). All kernels corresponding to a given\n",
|
||||
"provider live a directory corresponding to that provider under\n",
|
||||
"`cutlass_api/providers`. For example, CuTe DSL kernels live\n",
|
||||
"under `cutlass_api/providers/cutedsl`.\n",
|
||||
"\n",
|
||||
"Each provider can organize kernels differently. For CuTe DSL,\n",
|
||||
"kernels are further split based on their logical operation,\n",
|
||||
"with GEMM kernels under the `cutlass_api/providers/cutedsl/gemm`\n",
|
||||
"directory.\n",
|
||||
"\n",
|
||||
"We recommend separating the implementation of the kernel from\n",
|
||||
"its interface not just by using separate classes, as done in\n",
|
||||
"this example, but also by separating the implementation and\n",
|
||||
"interface into separate files. This makes it easier to update\n",
|
||||
"each without affecting the other.\n",
|
||||
"\n",
|
||||
"For example, CuTe DSL GEMM kernels have the following organization:\n",
|
||||
"```text\n",
|
||||
"cutlass_api/\n",
|
||||
" providers/\n",
|
||||
" cutedsl/\n",
|
||||
" gemm/\n",
|
||||
" sm100_static_persistent.py\n",
|
||||
" implementations/\n",
|
||||
" sm100_static_persistent_impl.py\n",
|
||||
"```"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,521 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f97e61c9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Best practices for reducing host-side latency"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a7a9c63c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Overall performance depends on both device performance (i.e., that of the kernel) and host performance (i.e., that of the runtime).\n",
|
||||
"This notebook focuses on the latter: techniques to minimize any overheads incurred from the CUTLASS API and underlying\n",
|
||||
"DSL runtimes.\n",
|
||||
"\n",
|
||||
"This notebook does not discuss techniques for improving device-side performance. A future notebook may cover this topic."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "e3ca9e40",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import time\n",
|
||||
"import torch\n",
|
||||
"import cutlass_api"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "efaac09c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
|
||||
" print(f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\")\n",
|
||||
" import sys\n",
|
||||
"\n",
|
||||
" sys.exit(0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "40de11ce",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We start with boilerplate initial setup to create tensors and pick a kernel.\n",
|
||||
"\n",
|
||||
"For the purposes of this notebook, we use a very small GEMM size of M=N=K=128\n",
|
||||
"and L=1. This small size is chosen to magnify the impact of host latency on\n",
|
||||
"end-to-end performance so as to better illustrate the effect of the techniques\n",
|
||||
"described below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "b8c44947",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"warmup_iterations = 10\n",
|
||||
"profiling_iterations = 100\n",
|
||||
"total_iterations = warmup_iterations + profiling_iterations\n",
|
||||
"\n",
|
||||
"# Use a small problem size to showcase host overheads\n",
|
||||
"L, M, N, K = 1, 128, 128, 128\n",
|
||||
"\n",
|
||||
"# We use different operands in each iteration. Though not particularly relevant for\n",
|
||||
"# host latency, this is a best practice when benchmarking GPU kernels to avoid\n",
|
||||
"# unrealistic caching effects.\n",
|
||||
"As = [torch.randint(-1, 2, (M, K), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
|
||||
"Bs = [torch.randint(-1, 2, (K, N), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
|
||||
"outs = [torch.empty((M, N), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
|
||||
"\n",
|
||||
"# Construct arguments outside of the benchmarking loop. We will later also consider\n",
|
||||
"# cases in which they are constructed inside the benchmarking loop.\n",
|
||||
"args = [cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16) for i in range(total_iterations)]\n",
|
||||
"\n",
|
||||
"references = [(As[i] @ Bs[i]).to(outs[i].dtype) for i in range(total_iterations)]\n",
|
||||
"\n",
|
||||
"kernels = cutlass_api.get_kernels(args[0], cc=100)\n",
|
||||
"\n",
|
||||
"assert len(kernels) > 0\n",
|
||||
"kernel = kernels[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f2e7eece",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We next set up a basic benchmarking routine."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "2472eafa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def benchmark(label, code, warmup_it=warmup_iterations, profiling_it=profiling_iterations):\n",
|
||||
" total_it = warmup_it + profiling_it\n",
|
||||
" assert total_it <= total_iterations, f\"Benchmark-local iteration count must be less than or equal to total iterations: {total_it} > {total_iterations}\"\n",
|
||||
" # warmup\n",
|
||||
" rets = [None] * total_it\n",
|
||||
" for i in range(warmup_it):\n",
|
||||
" rets[i] = code(i)\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
"\n",
|
||||
" start = time.time()\n",
|
||||
" for i in range(profiling_it):\n",
|
||||
" idx = warmup_it + i\n",
|
||||
" rets[idx] = code(idx)\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" end = time.time()\n",
|
||||
"\n",
|
||||
" avg_time = (end - start) / profiling_it\n",
|
||||
" print(f\"[{label:<30}] avg of {profiling_it} iterations: {avg_time:1.3e} seconds\")\n",
|
||||
" return avg_time, rets"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4909a76b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We now describe techniques for reducing host latency:\n",
|
||||
"* Compile once, run many times\n",
|
||||
"* Bypassing checks for argument-kernel compatibility\n",
|
||||
"* Using [CUDA Graphs](https://developer.nvidia.com/blog/cuda-graphs/)\n",
|
||||
"* Using [TVM FFI](https://tvm.apache.org/ffi/)\n",
|
||||
"\n",
|
||||
"These techniques are complementary and should be used together when applicable\n",
|
||||
"for an application."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "06495033",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Compile once, run many times\n",
|
||||
"The `kernel.run` method takes in an optional `compiled_artifact` argument of type\n",
|
||||
"`cutlass_api.artifact.CompiledArtifact`. When this argument is set, the kernel\n",
|
||||
"will directly use the precompiled function within `compiled_argument`. When\n",
|
||||
"it is not set, the call to `kernel.run` will JIT compile the kernel on each\n",
|
||||
"invocation.\n",
|
||||
"\n",
|
||||
"Precompiling the kernel is critical to achieving good performance."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "6de11f56",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"stream = torch.cuda.current_stream()\n",
|
||||
"def no_compiled_artifact(i: int):\n",
|
||||
" return kernel.run(args[i], stream=stream)\n",
|
||||
"\n",
|
||||
"# Compile the kernel once, reuse for each iterations\n",
|
||||
"compiled_artifact = kernel.compile(args[0])\n",
|
||||
"\n",
|
||||
"def with_compiled_artifact(i: int):\n",
|
||||
" return kernel.run(args[i], stream=stream, compiled_artifact=compiled_artifact)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "350c9bd6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Without compiled artifact ] avg of 5 iterations: 1.376e+00 seconds\n",
|
||||
"[With compiled artifact ] avg of 5 iterations: 1.016e-05 seconds\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"time_no_artifact, _ = benchmark(f\"Without compiled artifact\", no_compiled_artifact, warmup_it=2, profiling_it=5)\n",
|
||||
"time_w_artifact, _ = benchmark(f\"With compiled artifact\", with_compiled_artifact, warmup_it=2, profiling_it=5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5cfbc2d2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Bypassing checks for argument-kernel compatibility\n",
|
||||
"By default, the call to `kernel.run` will check if the kernel supports the provided arguments.\n",
|
||||
"Under the hood, this invokes `kernel.supports(args)`.\n",
|
||||
"\n",
|
||||
"While these checks are helpful for catching incompatible arguments, they are performed\n",
|
||||
"in Python, and thus can add to host overhead.\n",
|
||||
"\n",
|
||||
"When confident that arguments will be compatible with a kernel, one should bypass\n",
|
||||
"the `supports` check in `kernel.run` by setting the optional `assume_supported_args`\n",
|
||||
"argument to `True`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "5b93dfae",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def with_supports_check(i: int):\n",
|
||||
" return kernel.run(args[i], compiled_artifact=compiled_artifact, stream=stream, assume_supported_args=False)\n",
|
||||
"\n",
|
||||
"def without_supports_check(i: int):\n",
|
||||
" return kernel.run(args[i], compiled_artifact=compiled_artifact, stream=stream, assume_supported_args=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "b282f437",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[With supports check ] avg of 100 iterations: 1.463e-05 seconds\n",
|
||||
"[Bypass supports check ] avg of 100 iterations: 6.239e-06 seconds\n",
|
||||
"Speedup with skip supports: 2.34x\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"time_w_supports, _ = benchmark(\"With supports check\", with_supports_check)\n",
|
||||
"time_wo_supports, _ = benchmark(\"Bypass supports check\", without_supports_check)\n",
|
||||
"print(f\"Speedup with skip supports: {time_w_supports / time_wo_supports:.2f}x\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d74cb3e7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### CUDA Graphs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "656d5e2c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"CUTLASS API supports [CUDA Graphs](https://developer.nvidia.com/blog/cuda-graphs/) usage with PyTorch as usual.\n",
|
||||
"\n",
|
||||
"The kernel compilation must happen outside the CUDA graph. Then, we create a graph using usual PyTorch idioms to launch a kernel several times on the graph's stream."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "e614509f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"num_launches = 20\n",
|
||||
"\n",
|
||||
"# Create a CUDA Graph to run our compiled kernel N times\n",
|
||||
"g = torch.cuda.CUDAGraph()\n",
|
||||
"with torch.cuda.graph(g):\n",
|
||||
" # Run N iterations of our compiled kernel on the current stream\n",
|
||||
" for i in range(num_launches):\n",
|
||||
" kernel.run(\n",
|
||||
" args[i],\n",
|
||||
" compiled_artifact=compiled_artifact,\n",
|
||||
" stream=torch.cuda.current_stream(),\n",
|
||||
" assume_supported_args=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"# Zero the output so we don't refcheck stale results\n",
|
||||
"_ = outs[0].zero_()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8fc69c6e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Once captured, we can replay the graph. This will only replay the kernel launches placed on the CUDA stream.\n",
|
||||
"Any other prepratory work on the host and arguments passed in from python are cached during the capture."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "d9c5d5c5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Replay captured graph and check first result\n",
|
||||
"g.replay()\n",
|
||||
"\n",
|
||||
"torch.testing.assert_close(outs[0], references[0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "388c8e02",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's compare the timing:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "45d4e739",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[20 launches without CUDA Graph] avg of 1 iterations: 4.699e-04 seconds\n",
|
||||
"[20 launches with CUDA Graph ] avg of 1 iterations: 9.084e-05 seconds\n",
|
||||
"Speedup with CUDA Graph: 5.17x\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def without_cuda_graph(x: int):\n",
|
||||
" for i in range(num_launches):\n",
|
||||
" kernel.run(\n",
|
||||
" args[i],\n",
|
||||
" compiled_artifact=compiled_artifact,\n",
|
||||
" stream=torch.cuda.current_stream(),\n",
|
||||
" assume_supported_args=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"def with_cuda_graph(x: int):\n",
|
||||
" g.replay()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"time_wo_cuda_graph, _ = benchmark(f\"{num_launches} launches without CUDA Graph\", without_cuda_graph, warmup_it=0, profiling_it=1)\n",
|
||||
"time_w_cuda_graph, _ = benchmark(f\"{num_launches} launches with CUDA Graph\", with_cuda_graph, warmup_it=0, profiling_it=1)\n",
|
||||
"\n",
|
||||
"print(f\"Speedup with CUDA Graph: {time_wo_cuda_graph / time_w_cuda_graph:.2f}x\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fe5c3168",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### TVM FFI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ee7f9fd2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"When applicable, CUTLASS API uses [Apache TVM FFI](https://tvm.apache.org/ffi/) under the hood for invoking compiled DSL kernels from Python.\n",
|
||||
"Apache TVM FFI is an open ABI and FFI for machine learning systems.\n",
|
||||
"\n",
|
||||
"TVM FFI is enabled by default in CUTLASS API, and is recommended for best performance."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1690bbed",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`cutlass_api.config.GlobalOptions().use_tvm_ffi` controls whether or not TVM-FFI will be used by CUTLASS API."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "993c60ae",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"True\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(cutlass_api.config.GlobalOptions().use_tvm_ffi)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "00ed9a40",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If for some reason you do not wish to use it, this section demonstrates how, you can set this to False. No other change is needed. The below code compares the performance with and without TVM-FFI."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "e8f56be3",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[[TVM-FFI ON ] Create args ] avg of 100 iterations: 8.367e-05 seconds\n",
|
||||
"[[TVM-FFI ON ] Compile kernel ] avg of 5 iterations: 1.352e+00 seconds\n",
|
||||
"[[TVM-FFI ON ] Run kernel ] avg of 100 iterations: 6.509e-06 seconds\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"cutlass_api.config.GlobalOptions().use_tvm_ffi = True\n",
|
||||
"\n",
|
||||
"def run_iteration(i):\n",
|
||||
" args = cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16)\n",
|
||||
" return kernel.run(\n",
|
||||
" args,\n",
|
||||
" compiled_artifact=compiled_artifact,\n",
|
||||
" stream=torch.cuda.current_stream(),\n",
|
||||
" assume_supported_args=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"def create_arguments(i: int):\n",
|
||||
" return cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16)\n",
|
||||
"\n",
|
||||
"args_creation_on, args = benchmark(\"[TVM-FFI ON ] Create args\", create_arguments)\n",
|
||||
"compilation_on, compiled = benchmark(\"[TVM-FFI ON ] Compile kernel\", lambda i: kernel.compile(args[i]), warmup_it=2, profiling_it=5)\n",
|
||||
"compiled_artifact = compiled[0]\n",
|
||||
"run_on, _ = benchmark(\"[TVM-FFI ON ] Run kernel\", lambda i: kernel.run(args[i], compiled_artifact=compiled_artifact, assume_supported_args=True, stream=stream))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "5a4c2db4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[[TVM-FFI OFF ] Create args ] avg of 100 iterations: 1.255e-04 seconds\n",
|
||||
"[[TVM-FFI OFF ] Compile kernel ] avg of 5 iterations: 1.278e+00 seconds\n",
|
||||
"[[TVM-FFI OFF ] Run kernel ] avg of 100 iterations: 4.519e-05 seconds\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"cutlass_api.config.GlobalOptions().use_tvm_ffi = False\n",
|
||||
"args_creation_off, args = benchmark(\"[TVM-FFI OFF ] Create args\", create_arguments)\n",
|
||||
"compilation_off, compiled = benchmark(\"[TVM-FFI OFF ] Compile kernel\", lambda i: kernel.compile(args[i]), warmup_it=2, profiling_it=5)\n",
|
||||
"compiled_artifact = compiled[0]\n",
|
||||
"run_off, _ = benchmark(\"[TVM-FFI OFF ] Run kernel\", lambda i: kernel.run(args[i], compiled_artifact=compiled_artifact, assume_supported_args=True, stream=stream))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "17b43718",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Speedups with TVM-FFI: \n",
|
||||
"Arg creation: 1.50x\n",
|
||||
"Compilation: 0.95x\n",
|
||||
"Run: 6.94x\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(\"Speedups with TVM-FFI: \")\n",
|
||||
"print(f\"Arg creation: {args_creation_off / args_creation_on:.2f}x\")\n",
|
||||
"print(f\"Compilation: {compilation_off / compilation_on:.2f}x\")\n",
|
||||
"print(f\"Run: {run_off / run_on:.2f}x\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
76
python/cutlass_api/pyproject.toml
Normal file
76
python/cutlass_api/pyproject.toml
Normal file
@@ -0,0 +1,76 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "cutlass_api"
|
||||
dynamic = ["version"]
|
||||
requires-python = ">=3.12"
|
||||
|
||||
dependencies = [
|
||||
"nvidia-cutlass-dsl",
|
||||
"apache-tvm-ffi", # Required for optimal performance. To opt-out: uninstall or set cutlass_api.config.GlobalOptions().use_tvm_ffi=False
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
torch = [
|
||||
"torch",
|
||||
"torch-c-dlpack-ext",
|
||||
]
|
||||
test = [
|
||||
"jupyter",
|
||||
"pytest",
|
||||
"cutlass_api[torch]",
|
||||
]
|
||||
dev = [
|
||||
"cutlass_api[test]",
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
version = {attr = "cutlass_api.__version__"}
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py312"
|
||||
src = ["cutlass_api"]
|
||||
|
||||
# Exclude fusion/ from all ruff checks
|
||||
extend-exclude = [
|
||||
"cutlass_api/fusion",
|
||||
]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"F", # pyflakes
|
||||
"E", # pycodestyle errors
|
||||
"W", # pycodestyle warnings
|
||||
"I", # isort
|
||||
"UP", # pyupgrade
|
||||
"B", # flake8-bugbear
|
||||
"SIM", # flake8-simplify
|
||||
"C4", # flake8-comprehensions
|
||||
"PTH", # flake8-use-pathlib
|
||||
"FA", # flake8-future-annotations
|
||||
"PERF", # perflint - performance anti-patterns
|
||||
"FURB", # refurb - modern Python idioms
|
||||
"PIE", # flake8-pie - misc improvements
|
||||
"TCH", # flake8-type-checking
|
||||
"PLE", # pylint errors
|
||||
"RSE", # flake8-raise
|
||||
]
|
||||
ignore = [
|
||||
"SIM103", # nedless-bool - can be noisy
|
||||
"B905", # zip(strict=) - too strict
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
section-order = ["future", "standard-library", "third-party", "cutlass", "first-party", "local-folder"]
|
||||
known-first-party = ["cutlass_api"]
|
||||
|
||||
[tool.ruff.lint.isort.sections]
|
||||
cutlass = ["cutlass"]
|
||||
166
python/cutlass_api/test/integration/test_cuda_graph.py
Normal file
166
python/cutlass_api/test/integration/test_cuda_graph.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.cuda import current_stream
|
||||
|
||||
import cutlass_api
|
||||
from cutlass_api.utils import is_device_cc_supported
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M, N, K",
|
||||
[
|
||||
(256, 512, 1024),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"ab_dtype, c_dtype, accumulator_type",
|
||||
[
|
||||
(torch.float16, torch.float16, torch.float16),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"n_iterations",
|
||||
[
|
||||
20,
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not is_device_cc_supported({100})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_sm100(
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
ab_dtype: torch.dtype,
|
||||
c_dtype: torch.dtype,
|
||||
accumulator_type: torch.dtype,
|
||||
n_iterations: int,
|
||||
):
|
||||
A = torch.randint(-1, 2, (M, K), device="cuda").to(ab_dtype)
|
||||
B = torch.randint(-1, 2, (K, N), device="cuda").to(ab_dtype)
|
||||
D = torch.randint(-1, 2, (M, N), device="cuda").to(c_dtype)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernel = kernels[0]
|
||||
|
||||
"""
|
||||
Compile the kernel and capture CUDA Graph.
|
||||
The kernel needs to be compiled outside the CUDA graph.
|
||||
"""
|
||||
|
||||
assert kernel.supports(args)
|
||||
compiled_artifact = kernel.compile(args)
|
||||
stream = torch.cuda.Stream()
|
||||
|
||||
# Create a CUDA Graph to run our compiled kernel N times
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
# Run N iterations of our compiled kernel on the current stream
|
||||
for _ in range(n_iterations):
|
||||
kernel.run(
|
||||
args,
|
||||
compiled_artifact=compiled_artifact,
|
||||
stream=current_stream(),
|
||||
assume_supported_args=True,
|
||||
)
|
||||
|
||||
# Zero the output so we don't refcheck stale results
|
||||
D.zero_()
|
||||
|
||||
# Replay captured graph & check result
|
||||
g.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
reference = A @ B
|
||||
assert torch.allclose(D, reference.to(D.dtype)), "Refcheck failed!"
|
||||
|
||||
"""
|
||||
Run with & without graph capture to compare overhead
|
||||
"""
|
||||
|
||||
# Create CUDA events for measuring performance
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
# Warmup the GPU
|
||||
for _ in range(n_iterations):
|
||||
kernel.run(
|
||||
args,
|
||||
compiled_artifact=compiled_artifact,
|
||||
stream=stream,
|
||||
assume_supported_args=True,
|
||||
)
|
||||
|
||||
# Run without CUDA graph and time it
|
||||
start.record()
|
||||
for _ in range(n_iterations):
|
||||
kernel.run(
|
||||
args,
|
||||
compiled_artifact=compiled_artifact,
|
||||
stream=stream,
|
||||
assume_supported_args=True,
|
||||
)
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
without_graph_time = start.elapsed_time(end)
|
||||
|
||||
# Warmup again
|
||||
for _ in range(n_iterations):
|
||||
kernel.run(
|
||||
args,
|
||||
compiled_artifact=compiled_artifact,
|
||||
stream=stream,
|
||||
assume_supported_args=True,
|
||||
)
|
||||
|
||||
# Run with CUDA graph and time it
|
||||
start.record()
|
||||
g.replay()
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
with_graph_time = start.elapsed_time(end)
|
||||
|
||||
percent_speedup = (without_graph_time - with_graph_time) / with_graph_time
|
||||
print("-" * 80)
|
||||
print(f"Number of launches : {n_iterations}")
|
||||
print(f"Time without CUDA graph: {without_graph_time:.2f} ms")
|
||||
print(f"Time with CUDA graph: {with_graph_time:.2f} ms")
|
||||
print(f"Speedup : {percent_speedup * 100.0:.2f}%")
|
||||
69
python/cutlass_api/test/integration/test_elementwise_add.py
Normal file
69
python/cutlass_api/test/integration/test_elementwise_add.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import cutlass_api
|
||||
from cutlass_api.utils import device_cc
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M, N",
|
||||
[
|
||||
(256, 512),
|
||||
(1024, 8192),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"dtype",
|
||||
[
|
||||
torch.float32,
|
||||
torch.float16,
|
||||
],
|
||||
)
|
||||
def test_elementwise_add(
|
||||
M: int,
|
||||
N: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
A = torch.randint(-1, 2, (M, N), device="cuda", dtype=dtype)
|
||||
B = torch.randint(-1, 2, (M, N), device="cuda", dtype=dtype)
|
||||
D = torch.empty((M, N), device="cuda", dtype=dtype)
|
||||
|
||||
args = cutlass_api.arguments.ElementwiseArguments(A=A, B=B, out=D)
|
||||
kernels = cutlass_api.get_kernels(args)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernel = kernels[0]
|
||||
kernel.run(args)
|
||||
|
||||
reference = A + B
|
||||
|
||||
assert torch.allclose(D, reference)
|
||||
211
python/cutlass_api/test/integration/test_gemm.py
Normal file
211
python/cutlass_api/test/integration/test_gemm.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
import logging
|
||||
import os
|
||||
from importlib.util import find_spec
|
||||
from pprint import pformat
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import cutlass
|
||||
|
||||
import cutlass_api
|
||||
from cutlass_api.config import GlobalOptions
|
||||
from cutlass_api.utils import is_device_cc_supported
|
||||
|
||||
torch.manual_seed(2025)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M, N, K, L",
|
||||
[
|
||||
(256, 512, 1024, 1),
|
||||
(256, 512, 64, 1),
|
||||
(256, 512, 64, 2),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"ab_dtype, c_dtype, accumulator_type",
|
||||
[
|
||||
(torch.float16, torch.float32, torch.float32),
|
||||
(torch.float16, torch.float16, torch.float16),
|
||||
(torch.bfloat16, torch.bfloat16, torch.float32),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"use_tvm_ffi",
|
||||
[True, False],
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not is_device_cc_supported({100})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_sm100(
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
L: int,
|
||||
ab_dtype: torch.dtype,
|
||||
c_dtype: torch.dtype,
|
||||
accumulator_type: torch.dtype,
|
||||
use_tvm_ffi: bool,
|
||||
):
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
|
||||
|
||||
GlobalOptions().use_tvm_ffi = use_tvm_ffi
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type)
|
||||
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernel = kernels[0]
|
||||
logger.debug(f"Picked kernel: {kernel.metadata.kernel_name}")
|
||||
logger.debug(f"Kernel metadata:\n{pformat(kernel.metadata)}")
|
||||
kernel.run(args)
|
||||
|
||||
reference = A @ B
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_device_cc_supported({100})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a",
|
||||
)
|
||||
def test_gemm_sm100_int8():
|
||||
M, N, K = 256, 512, 128
|
||||
A = torch.randint(-1, 2, (M, K), device="cuda", dtype=torch.int8)
|
||||
B = torch.randint(-1, 2, (K, N), device="cuda", dtype=torch.int8)
|
||||
D = torch.empty((M, N), device="cuda", dtype=torch.int32)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type=torch.int32)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernel = kernels[0]
|
||||
assert kernel.supports(args)
|
||||
compiled_artifact = kernel.compile(args)
|
||||
kernel.run(args, compiled_artifact=compiled_artifact, assume_supported_args=True)
|
||||
|
||||
reference = torch._int_mm(A, B)
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
find_spec("tvm_ffi") is None,
|
||||
reason="FP8 currently requires TVM FFI to be installed",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not is_device_cc_supported({100})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a",
|
||||
)
|
||||
def test_gemm_sm100_fp8():
|
||||
# Currently, FP8 is only supported via TVM FFI.
|
||||
GlobalOptions().use_tvm_ffi = True
|
||||
|
||||
M, N, K = 256, 512, 128
|
||||
# Create torch fp8 tensors for A and B
|
||||
A = torch.randint(-1, 2, (M, K), device="cuda").to(torch.float8_e4m3fn)
|
||||
D = torch.empty((M, N), device="cuda", dtype=torch.float32)
|
||||
|
||||
# Transpose B because torch._scaled_mm expects B to be column-major
|
||||
B = torch.randint(-1, 2, (N, K), device="cuda").to(torch.float8_e4m3fn).T
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type=torch.float32)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernel = kernels[0]
|
||||
assert kernel.supports(args)
|
||||
compiled_artifact = kernel.compile(args)
|
||||
kernel.run(args, compiled_artifact=compiled_artifact, assume_supported_args=True)
|
||||
|
||||
identity_scale = torch.ones(1, device="cuda", dtype=torch.float32)
|
||||
reference = torch._scaled_mm(
|
||||
A, B, identity_scale, identity_scale, out_dtype=torch.float32
|
||||
)
|
||||
torch.testing.assert_close(D, reference)
|
||||
|
||||
|
||||
def test_no_gemms_available():
|
||||
M = N = K = 128
|
||||
L = 1
|
||||
A = torch.empty((L, M, K)).to(torch.float32)
|
||||
B = torch.empty((L, K, N)).to(torch.float32)
|
||||
D = torch.empty((L, M, N)).to(torch.float32)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type=torch.float32)
|
||||
kernels = cutlass_api.get_kernels(args, cc=70)
|
||||
|
||||
# There are currenlty no kernels available for compute capability 70.
|
||||
assert len(kernels) == 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_device_cc_supported({100}),
|
||||
reason="Requires compute capability 100",
|
||||
)
|
||||
def test_metadata_filter():
|
||||
# Test supplying metadata filter only
|
||||
def tile_size_m_filter(metadata: cutlass_api.metadata.KernelMetadata) -> bool:
|
||||
if not isinstance(metadata.design, cutlass_api.metadata.Sm100DesignMetadata):
|
||||
return False
|
||||
return metadata.design.tile_shape[0] == 64
|
||||
|
||||
kernels = cutlass_api.get_kernels(cc=100, metadata_filter=tile_size_m_filter)
|
||||
for kernel in kernels:
|
||||
assert kernel.metadata.design.tile_shape[0] == 64, (
|
||||
f"Kernel {kernel.metadata.kernel_name} has tile shape {kernel.metadata.design.tile_shape}"
|
||||
)
|
||||
|
||||
# Test supplying metadata filter and arguments
|
||||
A = torch.randint(-1, 2, (1, 256, 256), device="cuda").to(torch.float16)
|
||||
B = torch.randint(-1, 2, (1, 256, 256), device="cuda").to(torch.float16)
|
||||
D = torch.empty((1, 256, 256), device="cuda").to(torch.float16)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type=torch.float16)
|
||||
kernels = cutlass_api.get_kernels(
|
||||
args=args, cc=100, metadata_filter=tile_size_m_filter
|
||||
)
|
||||
for kernel in kernels:
|
||||
assert kernel.metadata.design.tile_shape[0] == 64, (
|
||||
f"Kernel {kernel.metadata.kernel_name} has tile shape {kernel.metadata.design.tile_shape}"
|
||||
)
|
||||
assert kernel.metadata.operands.A.dtype == cutlass.Float16
|
||||
assert kernel.metadata.operands.B.dtype == cutlass.Float16
|
||||
assert kernel.metadata.operands.out.dtype == cutlass.Float16
|
||||
assert kernel.metadata.operands.accumulator_type == cutlass.Float16
|
||||
833
python/cutlass_api/test/integration/test_gemm_epilogue_fusion.py
Normal file
833
python/cutlass_api/test/integration/test_gemm_epilogue_fusion.py
Normal file
@@ -0,0 +1,833 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import cutlass_api
|
||||
from cutlass_api.utils import device_cc
|
||||
from cutlass_api.config import GlobalOptions
|
||||
|
||||
torch.manual_seed(2025)
|
||||
|
||||
|
||||
def problem_sizes():
|
||||
"""
|
||||
Problem sizes for tests
|
||||
"""
|
||||
return [
|
||||
(256, 512, 1024, 1),
|
||||
(256, 512, 128, 1),
|
||||
(256, 512, 128, 2),
|
||||
]
|
||||
|
||||
|
||||
def base_data_types():
|
||||
"""
|
||||
Data types for (ab, c, d, accumulator)
|
||||
"""
|
||||
return [
|
||||
(torch.float16, torch.float32, torch.float32, torch.float32),
|
||||
(torch.float16, torch.float16, torch.float16, torch.float16),
|
||||
(torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32),
|
||||
]
|
||||
|
||||
|
||||
def supports_sm100af():
|
||||
return device_cc() == 100 and (
|
||||
os.getenv("CUTE_DSL_ARCH", "") in ["", "sm_100a", "sm_100f"]
|
||||
)
|
||||
|
||||
|
||||
# Unary operation strings and functions
|
||||
identity = ("", lambda x: x)
|
||||
relu = ("relu", torch.relu)
|
||||
tanh = ("tanh", torch.tanh)
|
||||
sigmoid = ("sigmoid", torch.sigmoid)
|
||||
exp = ("exp", torch.exp)
|
||||
|
||||
unary_ops = [identity, relu, tanh, sigmoid, exp]
|
||||
|
||||
|
||||
# Binary operation strings and functions
|
||||
add = (lambda a, b: f"{a} + {b}", lambda a, b: a + b)
|
||||
sub = (lambda a, b: f"{a} - {b}", lambda a, b: a - b)
|
||||
mul = (lambda a, b: f"{a} * {b}", lambda a, b: a * b)
|
||||
|
||||
# Don't include divide in main binary ops due to issues with division by zero in refchecks
|
||||
binary_ops = [add, sub, mul]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N, K, L", problem_sizes())
|
||||
# Restrict to D of float16 for now to avoid rounding error when converting torch f16 output to f32
|
||||
@pytest.mark.parametrize(
|
||||
"ab_dtype, d_dtype, accumulator_type",
|
||||
[(torch.float16, torch.float16, torch.float16)],
|
||||
)
|
||||
@pytest.mark.parametrize("unary_str, unary_op", unary_ops)
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_unary(
|
||||
M, N, K, L, ab_dtype, d_dtype, accumulator_type, unary_str, unary_op
|
||||
):
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
|
||||
def epi(accum):
|
||||
D = unary_op(accum)
|
||||
return D
|
||||
|
||||
epi_str = f"def epi(accum): D = {unary_str}(accum); return D"
|
||||
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, D=D)
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernels[0].run(args)
|
||||
|
||||
reference = epi(A @ B)
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N, K, L", [(256, 512, 128, 2)])
|
||||
# Restrict to D of float16 for now to avoid rounding error when converting torch f16 output to f32
|
||||
@pytest.mark.parametrize(
|
||||
"ab_dtype, d_dtype, accumulator_type",
|
||||
[(torch.float16, torch.float16, torch.float16)],
|
||||
)
|
||||
@pytest.mark.parametrize("unary_str, unary_op", [relu])
|
||||
@pytest.mark.parametrize("unary_str2, unary_op2", [sigmoid, tanh])
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_unary_composition(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
L,
|
||||
ab_dtype,
|
||||
d_dtype,
|
||||
accumulator_type,
|
||||
unary_str,
|
||||
unary_op,
|
||||
unary_str2,
|
||||
unary_op2,
|
||||
):
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
|
||||
def epi(accum):
|
||||
D = unary_op2(unary_op(accum))
|
||||
return D
|
||||
|
||||
epi_str = f"def epi(accum): D = {unary_str2}({unary_str}(accum)); return D"
|
||||
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, D=D)
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernels[0].run(args)
|
||||
|
||||
reference = epi(A @ B)
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N, K, L", problem_sizes())
|
||||
# Restrict to D of float16 for now to avoid rounding error when converting torch f16 output to f32
|
||||
@pytest.mark.parametrize(
|
||||
"ab_dtype, d_dtype, accumulator_type",
|
||||
[(torch.float16, torch.float16, torch.float16)],
|
||||
)
|
||||
# Restrict unary to identity and relu to avoid rounding errors
|
||||
@pytest.mark.parametrize("unary_str, unary_op", [identity, relu])
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_unary_literal(
|
||||
M, N, K, L, ab_dtype, d_dtype, accumulator_type, unary_str, unary_op
|
||||
):
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
|
||||
def epi(accum):
|
||||
D = unary_op(accum) * 3.0 - 1.234
|
||||
return D
|
||||
|
||||
epi_str = f"def epi(accum): D = {unary_str}(accum) * 3.0 - 1.234; return D"
|
||||
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, D=D)
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernels[0].run(args)
|
||||
|
||||
reference = epi(A @ B)
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N, K, L", problem_sizes())
|
||||
@pytest.mark.parametrize(
|
||||
"ab_dtype, c_dtype, d_dtype, accumulator_type", base_data_types()
|
||||
)
|
||||
@pytest.mark.parametrize("unary_str, unary_op", [identity, relu])
|
||||
@pytest.mark.parametrize("binary_str, binary_op", binary_ops)
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_unary_binary_composition(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
L,
|
||||
ab_dtype,
|
||||
c_dtype,
|
||||
d_dtype,
|
||||
accumulator_type,
|
||||
unary_str,
|
||||
unary_op,
|
||||
binary_str,
|
||||
binary_op,
|
||||
):
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
C = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
|
||||
def epi(accum, C):
|
||||
z = unary_op(accum)
|
||||
D = binary_op(z, C)
|
||||
return D
|
||||
|
||||
epi_str = f"def epi(accum, C): z = {unary_str}(accum); D = {binary_str('z', 'C')}; return D"
|
||||
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, C=C, D=D)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernels[0].run(args)
|
||||
|
||||
reference = epi(A @ B, C)
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N, K, L", problem_sizes())
|
||||
# Restrict to D of float16 for now to avoid rounding error when converting torch f16 output to f32
|
||||
@pytest.mark.parametrize(
|
||||
"ab_dtype, d_dtype, accumulator_type",
|
||||
[(torch.float16, torch.float16, torch.float16)],
|
||||
)
|
||||
@pytest.mark.parametrize("c0_dtype", [torch.float16, torch.float32])
|
||||
@pytest.mark.parametrize("c1_dtype", [torch.float16, torch.float32])
|
||||
@pytest.mark.parametrize("binary_str0, binary_op0", [add, sub])
|
||||
@pytest.mark.parametrize("binary_str1, binary_op1", [add, sub])
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_binary_binary_composition(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
L,
|
||||
ab_dtype,
|
||||
d_dtype,
|
||||
accumulator_type,
|
||||
c0_dtype,
|
||||
c1_dtype,
|
||||
binary_str0,
|
||||
binary_op0,
|
||||
binary_str1,
|
||||
binary_op1,
|
||||
):
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
C0 = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c0_dtype)
|
||||
C1 = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c1_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
|
||||
def epi(accum, C0, C1):
|
||||
z = torch.relu(accum)
|
||||
z1 = binary_op0(z, C0)
|
||||
D = binary_op1(z1, C1)
|
||||
return D
|
||||
|
||||
epi_str = f"def epi(accum, C0, C1): z = relu(accum); z1 = {binary_str0('z', 'C0')}; D = {binary_str1('z1', 'C1')}; return D"
|
||||
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, C0=C0, C1=C1, D=D)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernels[0].run(args)
|
||||
|
||||
reference = epi(A @ B, C0, C1)
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_division():
|
||||
M, N, K, L = 256, 512, 128, 2
|
||||
ab_dtype = torch.float16
|
||||
d_dtype = torch.float16
|
||||
accumulator_type = torch.float16
|
||||
|
||||
# Specifically initialize A and B with ones to avoid division by zero in refchecks
|
||||
A = torch.ones((L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.ones((L, K, N), device="cuda", dtype=ab_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
|
||||
#########################################################
|
||||
# Test division by a literal
|
||||
#########################################################
|
||||
def epi(accum):
|
||||
D = accum / 2.0
|
||||
return D
|
||||
|
||||
epi_str = "def epi(accum): D = accum / 2.0; return D"
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, D=D)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernels[0].run(args)
|
||||
|
||||
reference = epi(A @ B)
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
#########################################################
|
||||
# Test division by an input
|
||||
#########################################################
|
||||
def epi(accum, scalar):
|
||||
D = accum / scalar
|
||||
return D
|
||||
|
||||
epi_str = "def epi(accum, scalar): D = accum / scalar; return D"
|
||||
|
||||
scalar = 4.0
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, scalar=scalar, D=D)
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernels[0].run(args)
|
||||
|
||||
reference = epi(A @ B, scalar)
|
||||
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N, K, L", problem_sizes())
|
||||
# Restrict to D of float16 for now to avoid rounding error when converting torch f16 output to f32
|
||||
@pytest.mark.parametrize(
|
||||
"ab_dtype, d_dtype, accumulator_type",
|
||||
[(torch.float16, torch.float16, torch.float16)],
|
||||
)
|
||||
@pytest.mark.parametrize("unary_str, unary_op", [sigmoid, tanh])
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_unary_multi_output(
|
||||
M, N, K, L, ab_dtype, d_dtype, accumulator_type, unary_str, unary_op
|
||||
):
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
z = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
|
||||
def epi(accum):
|
||||
z0 = torch.relu(accum)
|
||||
z = unary_op(z0)
|
||||
D = z + z0
|
||||
return D, z
|
||||
|
||||
epi_str = f"def epi(accum): z0 = relu(accum); z = {unary_str}(z0); D = z + z0; return D, z"
|
||||
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, z=z, D=D)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernels[0].run(args)
|
||||
|
||||
ref_D, ref_z = epi(A @ B)
|
||||
torch.testing.assert_close(D, ref_D.to(D.dtype))
|
||||
torch.testing.assert_close(z, ref_z.to(z.dtype))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N, K, L", problem_sizes())
|
||||
# Restrict to D of float16 for now to avoid rounding error when converting torch f16 output to f32
|
||||
@pytest.mark.parametrize(
|
||||
"ab_dtype, d_dtype, accumulator_type",
|
||||
[(torch.float16, torch.float16, torch.float16)],
|
||||
)
|
||||
@pytest.mark.parametrize("c_dtype", [torch.float16, torch.float32])
|
||||
@pytest.mark.parametrize("binary_str, binary_op", binary_ops)
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_binary_multi_output(
|
||||
M, N, K, L, ab_dtype, d_dtype, accumulator_type, c_dtype, binary_str, binary_op
|
||||
):
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
C = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
z = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
|
||||
def epi(accum, C):
|
||||
z0 = torch.relu(accum)
|
||||
z = binary_op(z0, C)
|
||||
D = z + z0
|
||||
return D, z
|
||||
|
||||
epi_str = f"def epi(accum, C): z0 = relu(accum); z = {binary_str('z0', 'C')}; D = z + z0; return D, z"
|
||||
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, C=C, z=z, D=D)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernels[0].run(args)
|
||||
|
||||
ref_D, ref_z = epi(A @ B, C)
|
||||
torch.testing.assert_close(D, ref_D.to(D.dtype))
|
||||
torch.testing.assert_close(z, ref_z.to(z.dtype))
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_return_acc():
|
||||
M, N, K, L = 256, 512, 128, 2
|
||||
ab_dtype = torch.float16
|
||||
c_dtype = torch.float32
|
||||
d_dtype = torch.float16
|
||||
accumulator_type = torch.float16
|
||||
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
C = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
accum = torch.empty((L, M, N), device="cuda", dtype=accumulator_type)
|
||||
|
||||
def epi(accum, C):
|
||||
D = torch.relu(accum) + C
|
||||
return D, accum
|
||||
|
||||
epi_str = "def epi(accum, C): D = relu(accum) + C; return D, accum"
|
||||
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(
|
||||
epi_str, C=C, D=D, accum=accum
|
||||
)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernels[0].run(args)
|
||||
|
||||
reference, ref_accum = epi(A @ B, C)
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
torch.testing.assert_close(accum, ref_accum.to(accum.dtype))
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_acc_as_multiple_input():
|
||||
M, N, K, L = 256, 512, 128, 2
|
||||
ab_dtype = torch.float16
|
||||
c_dtype = torch.float32
|
||||
d_dtype = torch.float16
|
||||
accumulator_type = torch.float16
|
||||
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
C = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
|
||||
#########################################################
|
||||
# Test binary op inside
|
||||
#########################################################
|
||||
def epi(accum, C):
|
||||
D = torch.relu(torch.relu(accum) * C) + accum
|
||||
return D
|
||||
|
||||
epi_str = "def epi(accum, C): D = relu(relu(accum) * C) + accum; return D"
|
||||
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, C=C, D=D)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernels[0].run(args)
|
||||
|
||||
reference = epi(A @ B, C)
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
#########################################################
|
||||
# Test unary op inside
|
||||
#########################################################
|
||||
def epi(accum):
|
||||
D = torch.relu(torch.sigmoid(torch.relu(accum))) + accum
|
||||
return D
|
||||
|
||||
epi_str = "def epi(accum): D = relu(sigmoid(relu(accum))) + accum; return D"
|
||||
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, D=D)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernels[0].run(args)
|
||||
|
||||
reference = epi(A @ B)
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_matmul_input_as_aux():
|
||||
M, N, K, L = 1024, 1024, 1024, 2
|
||||
ab_dtype = torch.float16
|
||||
c_dtype = torch.float32
|
||||
d_dtype = torch.float16
|
||||
accumulator_type = torch.float16
|
||||
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
C = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
|
||||
#########################################################
|
||||
# Test binary op inside
|
||||
#########################################################
|
||||
def epi(accum, C, A):
|
||||
D = torch.sigmoid(torch.relu(accum) * C) + A
|
||||
return D
|
||||
|
||||
epi_str = "def epi(accum, C, A): D = sigmoid(relu(accum) * C) + A; return D"
|
||||
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, C=C, A=A, D=D)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernels[0].run(args)
|
||||
|
||||
reference = epi(A @ B, C, A)
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
#########################################################
|
||||
# Test unary op inside
|
||||
#########################################################
|
||||
def epi(accum, A):
|
||||
D = torch.tanh(torch.relu(accum)) + A
|
||||
return D
|
||||
|
||||
epi_str = "def epi(accum, A): D = tanh(relu(accum)) + A; return D"
|
||||
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, A=A, D=D)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernels[0].run(args)
|
||||
|
||||
reference = epi(A @ B, A)
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N, K, L", problem_sizes())
|
||||
@pytest.mark.parametrize(
|
||||
"ab_dtype, c_dtype, d_dtype, accumulator_type", base_data_types()
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"use_tvm_ffi",
|
||||
[True, False],
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_alpha_beta(
|
||||
M, N, K, L, ab_dtype, c_dtype, d_dtype, accumulator_type, use_tvm_ffi
|
||||
):
|
||||
GlobalOptions().use_tvm_ffi = use_tvm_ffi
|
||||
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
C = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
|
||||
def epi(accum, C, alpha, beta):
|
||||
D = alpha * accum + beta * C
|
||||
return D
|
||||
|
||||
alpha = 0.5
|
||||
beta = 0.5
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(
|
||||
epi, C=C, alpha=alpha, beta=beta, D=D
|
||||
)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernel = kernels[0]
|
||||
|
||||
for a, b in [(0.5, 0.5), (1.0, 0.0), (0.0, 1.0)]:
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(
|
||||
epi, C=C, alpha=a, beta=b, D=D
|
||||
)
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
|
||||
)
|
||||
kernel.run(args)
|
||||
reference = epi(A @ B, C, a, b)
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_big_epi():
|
||||
M, N, K, L = 256, 512, 128, 2
|
||||
ab_dtype = torch.float16
|
||||
c_dtype = torch.float32
|
||||
d_dtype = torch.bfloat16
|
||||
accumulator_type = torch.float16
|
||||
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
|
||||
In0 = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
|
||||
In1 = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
|
||||
In2 = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
|
||||
In3 = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
|
||||
In4 = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
|
||||
In5 = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
|
||||
In6 = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
|
||||
In7 = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
|
||||
|
||||
Out0 = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
Out1 = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
Out2 = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
Out3 = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
Out4 = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
Out5 = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
Out6 = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=d_dtype)
|
||||
|
||||
sc0 = 1.0
|
||||
sc1 = 2.0
|
||||
sc2 = 3.0
|
||||
sc3 = 4.0
|
||||
sc4 = 5.0
|
||||
sc5 = 6.0
|
||||
sc6 = 7.0
|
||||
sc7 = 8.0
|
||||
|
||||
def epi(
|
||||
accum,
|
||||
In0,
|
||||
In1,
|
||||
In2,
|
||||
In3,
|
||||
In4,
|
||||
In5,
|
||||
In6,
|
||||
In7,
|
||||
sc0,
|
||||
sc1,
|
||||
sc2,
|
||||
sc3,
|
||||
sc4,
|
||||
sc5,
|
||||
sc6,
|
||||
sc7,
|
||||
):
|
||||
Out0 = accum * sc0 + In0
|
||||
Out1 = Out0 + In1 * sc1
|
||||
Out2 = Out1 - In2 * sc2
|
||||
Out3 = Out2 + In3 * sc3
|
||||
Out4 = Out3 - In4 * sc4
|
||||
Out5 = Out4 + In5 * sc5
|
||||
Out6 = Out5 - In6 * sc6
|
||||
D = Out6 + In7 * sc7
|
||||
return Out0, Out1, Out2, Out3, Out4, Out5, Out6, D
|
||||
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(
|
||||
epi,
|
||||
In0=In0,
|
||||
In1=In1,
|
||||
In2=In2,
|
||||
In3=In3,
|
||||
In4=In4,
|
||||
In5=In5,
|
||||
In6=In6,
|
||||
In7=In7,
|
||||
Out0=Out0,
|
||||
Out1=Out1,
|
||||
Out2=Out2,
|
||||
Out3=Out3,
|
||||
Out4=Out4,
|
||||
Out5=Out5,
|
||||
Out6=Out6,
|
||||
D=D,
|
||||
sc0=sc0,
|
||||
sc1=sc1,
|
||||
sc2=sc2,
|
||||
sc3=sc3,
|
||||
sc4=sc4,
|
||||
sc5=sc5,
|
||||
sc6=sc6,
|
||||
sc7=sc7,
|
||||
)
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernels[0].run(args)
|
||||
|
||||
reference = epi(
|
||||
A @ B,
|
||||
In0,
|
||||
In1,
|
||||
In2,
|
||||
In3,
|
||||
In4,
|
||||
In5,
|
||||
In6,
|
||||
In7,
|
||||
sc0,
|
||||
sc1,
|
||||
sc2,
|
||||
sc3,
|
||||
sc4,
|
||||
sc5,
|
||||
sc6,
|
||||
sc7,
|
||||
)
|
||||
|
||||
for out, ref in zip([Out0, Out1, Out2, Out3, Out4, Out5, Out6, D], reference):
|
||||
torch.testing.assert_close(out, ref.to(out.dtype))
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_not_available():
|
||||
M = 256
|
||||
N = 512
|
||||
K = 1024
|
||||
A = torch.randint(-1, 2, (M, K), device="cuda", dtype=torch.float16)
|
||||
B = torch.randint(-1, 2, (K, N), device="cuda", dtype=torch.float16)
|
||||
D = torch.empty((M, N), device="cuda", dtype=torch.float16)
|
||||
|
||||
# Non-scalar broadcasts are currently not supported
|
||||
bias = torch.randint(-1, 2, (M, 1), device="cuda", dtype=torch.float16)
|
||||
|
||||
def epi(accum, bias):
|
||||
D = accum + bias
|
||||
return D
|
||||
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(epi, bias=bias, D=D)
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=torch.float16, epilogue=epi_args
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) == 0
|
||||
86
python/cutlass_api/test/integration/test_notebooks.py
Normal file
86
python/cutlass_api/test/integration/test_notebooks.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
import cutlass_api
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"notebook_name, supported_ccs",
|
||||
[
|
||||
("000_gemm.ipynb", [100, 103]),
|
||||
("001_gemm_with_fused_epilogue.ipynb", [100, 103]),
|
||||
("002_bring_your_own_kernel.ipynb", [80, 89, 90, 100, 103, 120, 121]),
|
||||
("003_host_latency_best_practices.ipynb", [100, 103]),
|
||||
],
|
||||
)
|
||||
def test_notebooks(notebook_name, supported_ccs):
|
||||
possible_cc_strs = []
|
||||
for cc in supported_ccs:
|
||||
arch_conditional_ccs = [90, 100, 101, 103, 120, 121]
|
||||
family_conditional_ccs = [100, 103]
|
||||
|
||||
if cc in family_conditional_ccs:
|
||||
possible_cc_strs.extend([f"sm_{cc}f", f"sm_{cc}a"])
|
||||
elif cc in arch_conditional_ccs:
|
||||
possible_cc_strs.append(f"sm_{cc}a")
|
||||
else:
|
||||
possible_cc_strs.append(f"sm_{cc}")
|
||||
|
||||
cute_dsl_arch = os.environ.get("CUTE_DSL_ARCH", "")
|
||||
|
||||
# Add empty string to allow running without macro set
|
||||
possible_cc_strs.append("")
|
||||
|
||||
if (
|
||||
cutlass_api.utils.device_cc() not in supported_ccs
|
||||
or cute_dsl_arch not in possible_cc_strs
|
||||
):
|
||||
# Each test should gracefully exit(0) for CCs that are not supported, but
|
||||
# the nbconvert-based runner below has issues with this. Thus, we manually
|
||||
# skip here.
|
||||
pytest.skip(
|
||||
f"This notebook requires a GPU with compute capability {supported_ccs}"
|
||||
)
|
||||
|
||||
notebook_dir = os.path.join(os.path.dirname(__file__), "..", "..", "examples")
|
||||
full_notebook_path = os.path.join(notebook_dir, notebook_name)
|
||||
|
||||
import nbconvert
|
||||
import nbformat
|
||||
|
||||
with open(full_notebook_path, "r") as file:
|
||||
notebook = nbformat.read(file, as_version=4)
|
||||
ep = nbconvert.preprocessors.ExecutePreprocessor(timeout=600, kernel_name="python3")
|
||||
|
||||
# Execute the notebook. This call will error out on any assertions or errors in the
|
||||
# notebook itself. Allow these to propagate up so the test will fail on notebook failure.
|
||||
ep.preprocess(notebook, {"metadata": {"path": notebook_dir}})
|
||||
178
python/cutlass_api/test/integration/test_tvm_ffi.py
Normal file
178
python/cutlass_api/test/integration/test_tvm_ffi.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
import os
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
import pytest
|
||||
import torch
|
||||
from torch.cuda import current_stream
|
||||
|
||||
import cutlass_api
|
||||
from cutlass_api.config import GlobalOptions
|
||||
from cutlass_api.utils import is_device_cc_supported
|
||||
|
||||
|
||||
def benchmark(label, code, n_iterations):
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
for _ in range(n_iterations):
|
||||
code()
|
||||
|
||||
start.record()
|
||||
for _ in range(n_iterations):
|
||||
code()
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
avg_time = start.elapsed_time(end) / n_iterations
|
||||
print(f"[{label:20}] avg of {n_iterations} iterations: {avg_time:1.3e} ms")
|
||||
return avg_time
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M, N, K",
|
||||
[
|
||||
(256, 512, 1024),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"ab_dtype, c_dtype, accumulator_type",
|
||||
[
|
||||
(torch.float16, torch.float16, torch.float16),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"n_iterations",
|
||||
[
|
||||
50,
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not is_device_cc_supported({100})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_sm100(
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
ab_dtype: torch.dtype,
|
||||
c_dtype: torch.dtype,
|
||||
accumulator_type: torch.dtype,
|
||||
n_iterations: int,
|
||||
):
|
||||
print()
|
||||
|
||||
A = torch.randint(-1, 2, (M, K), device="cuda").to(ab_dtype)
|
||||
B = torch.randint(-1, 2, (K, N), device="cuda").to(ab_dtype)
|
||||
D = torch.empty((M, N), device="cuda").to(c_dtype)
|
||||
|
||||
GlobalOptions().use_tvm_ffi = True
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type
|
||||
)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
|
||||
assert len(kernels) > 0
|
||||
kernel = kernels[0]
|
||||
|
||||
"""
|
||||
Compile & run the kernel with TVM-FFI.
|
||||
"""
|
||||
|
||||
assert kernel.supports(args)
|
||||
compiled_artifact_with_tvm_ffi = kernel.compile(args)
|
||||
kernel.run(
|
||||
args,
|
||||
compiled_artifact=compiled_artifact_with_tvm_ffi,
|
||||
stream=current_stream(),
|
||||
assume_supported_args=True,
|
||||
)
|
||||
reference = A @ B
|
||||
assert torch.allclose(D, reference.to(D.dtype)), "Refcheck failed!"
|
||||
|
||||
# Also works with CUDA graphs
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
kernel.run(
|
||||
args,
|
||||
compiled_artifact=compiled_artifact_with_tvm_ffi,
|
||||
stream=current_stream(),
|
||||
assume_supported_args=True,
|
||||
)
|
||||
|
||||
D.zero_()
|
||||
g.replay()
|
||||
torch.cuda.synchronize()
|
||||
assert torch.allclose(D, reference.to(D.dtype)), "Refcheck failed!"
|
||||
|
||||
"""
|
||||
Create args & run with & without TVM-FFI to compare overhead
|
||||
"""
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type
|
||||
)
|
||||
compiled_artifact_with_tvm_ffi = kernel.compile(args)
|
||||
|
||||
# Run with TVM-FFI and time it
|
||||
avg_time_with_tvm_ffi = benchmark(
|
||||
"Run with TVM-FFI",
|
||||
lambda: kernel.run(
|
||||
args,
|
||||
compiled_artifact=compiled_artifact_with_tvm_ffi,
|
||||
stream=current_stream(),
|
||||
assume_supported_args=True,
|
||||
),
|
||||
n_iterations,
|
||||
)
|
||||
|
||||
# Run without TVM-FFI and time it
|
||||
GlobalOptions().use_tvm_ffi = False
|
||||
args_without_tvm_ffi = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type
|
||||
)
|
||||
compiled_artifact_without_tvm_ffi = kernel.compile(args_without_tvm_ffi)
|
||||
stream = cuda.CUstream(current_stream().cuda_stream)
|
||||
avg_time_without_tvm_ffi = benchmark(
|
||||
"Run without TVM-FFI",
|
||||
lambda: kernel.run(
|
||||
args_without_tvm_ffi,
|
||||
compiled_artifact=compiled_artifact_without_tvm_ffi,
|
||||
stream=stream,
|
||||
assume_supported_args=True,
|
||||
),
|
||||
n_iterations,
|
||||
)
|
||||
|
||||
speedup = avg_time_without_tvm_ffi / avg_time_with_tvm_ffi
|
||||
print(f"Speedup with TVM-FFI: {speedup:.3f}x")
|
||||
@@ -0,0 +1,95 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import cutlass_api
|
||||
|
||||
|
||||
pytestmark = pytest.mark.arch("80")
|
||||
|
||||
|
||||
def test_basic_keywords():
|
||||
def epi(accum, C, alpha, beta):
|
||||
D = (alpha * accum) + (beta * C)
|
||||
return D
|
||||
|
||||
cutlass_api.arguments.EpilogueArguments(
|
||||
epilogue_fn=epi,
|
||||
D=torch.randn(10, 10),
|
||||
C=torch.randn(10, 10),
|
||||
alpha=1.0,
|
||||
beta=1.0,
|
||||
)
|
||||
|
||||
|
||||
def test_missing_keywords():
|
||||
epifn = """
|
||||
def epi(accum, C, alpha, beta):
|
||||
D = (alpha * accum) + (beta * C)
|
||||
F = relu(D)
|
||||
return D, F
|
||||
"""
|
||||
try:
|
||||
# Missing F
|
||||
cutlass_api.arguments.EpilogueArguments(
|
||||
epilogue_fn=epifn,
|
||||
D=torch.randn(10, 10),
|
||||
C=torch.randn(10, 10),
|
||||
alpha=1.0,
|
||||
beta=1.0,
|
||||
)
|
||||
except ValueError as e:
|
||||
assert "F" in str(e)
|
||||
else:
|
||||
assert False, "Failed to catch missing keyword"
|
||||
|
||||
|
||||
def test_extra_keywords():
|
||||
epifn = """
|
||||
def epi(accum, C, alpha, beta):
|
||||
D = (alpha * accum) + (beta * C)
|
||||
F = relu(D)
|
||||
return D, F
|
||||
"""
|
||||
try:
|
||||
cutlass_api.arguments.EpilogueArguments(
|
||||
epilogue_fn=epifn,
|
||||
D=torch.randn(10, 10),
|
||||
C=torch.randn(10, 10),
|
||||
alpha=1.0,
|
||||
beta=1.0,
|
||||
F=torch.randn(10, 10),
|
||||
gamma=3.0,
|
||||
)
|
||||
except ValueError as e:
|
||||
assert "gamma" in str(e)
|
||||
else:
|
||||
assert False, "Failed to catch extra keyword"
|
||||
Reference in New Issue
Block a user