mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
2026-01-06 updates
This commit is contained in:
@@ -69,7 +69,7 @@ See CUTLASS's [Compatibility section](https://github.com/NVIDIA/cutlass?tab=read
|
||||
### Current support
|
||||
|
||||
* Dense GEMM: `out = A @ B`
|
||||
- Compute capabilities: 100, 103
|
||||
- Compute capabilities: 100, 103 (WIP to expand to more)
|
||||
- Input precisions (A and B must be of same type): F16, BF16, TF32, INT8
|
||||
- Output precisions: F32, F16, BF16, INT32
|
||||
- Epilogue operations:
|
||||
@@ -78,6 +78,7 @@ See CUTLASS's [Compatibility section](https://github.com/NVIDIA/cutlass?tab=read
|
||||
- Auxiliary load of scalar
|
||||
- Tensor-tensor elementwise or tensor-scalar addition, multiplication, subtraction, division
|
||||
- Elementwise tensor exponent, relu, sigmoid, tanh
|
||||
- Note: Partial support exists on CC 80/89/90 (limited dtypes/tilings coverage)
|
||||
|
||||
* Planned additions
|
||||
* Block-scaled GEMMs
|
||||
|
||||
@@ -60,9 +60,16 @@ class PerformanceControls:
|
||||
|
||||
|
||||
class EpilogueArguments:
|
||||
def __init__(self, epilogue_fn: Callable | str | None = None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
epilogue_fn: Callable | str,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Encapsulation of the epilogue function and its arguments needed to
|
||||
Describes a user-defined epilogue that is performed on top of the operation
|
||||
described by the primary `RuntimeArguments`.
|
||||
|
||||
It encapsulates an epilogue function and its arguments needed to
|
||||
determine the functional operation of an epilogue pattern.
|
||||
|
||||
To support flexible definition of epilogues, `EpilogueArguments` is
|
||||
@@ -120,7 +127,10 @@ class EpilogueArguments:
|
||||
```
|
||||
A user would need to construct epilogue arguments as follows:
|
||||
```python
|
||||
epi_args = EpilogueArguments(my_epi, alpha=..., C=..., beta=..., D=..., F=...)
|
||||
epi_args = EpilogueArguments(
|
||||
my_epi,
|
||||
alpha=..., C=..., beta=..., D=..., F=...
|
||||
)
|
||||
```
|
||||
|
||||
:param epilogue_fn: The epilogue function to be traced.
|
||||
@@ -263,13 +273,13 @@ class GemmArguments(RuntimeArguments):
|
||||
N: Number of columns in B and out
|
||||
|
||||
:param A: Input tensor A of shape (L, M, K) or (M, K)
|
||||
:type A: TensorWrapper
|
||||
:type A: TensorLike
|
||||
:param B: Input tensor B of shape (L, K, N) or (K, N)
|
||||
:type B: TensorWrapper
|
||||
:type B: TensorLike
|
||||
:param out: Output tensor C of shape (L, M, N) or (M, N)
|
||||
:type out: TensorWrapper
|
||||
:type out: TensorLike
|
||||
:param accumulator_type: Data type of the accumulator
|
||||
:type accumulator_type: cutlass.Numeric
|
||||
:type accumulator_type: NumericLike
|
||||
"""
|
||||
|
||||
A: TensorLike
|
||||
|
||||
@@ -81,3 +81,17 @@ class GlobalOptions:
|
||||
"TVM FFI is not installed, please install it via `pip install apache-tvm-ffi`."
|
||||
)
|
||||
self._options["use_tvm_ffi"] = value
|
||||
|
||||
def save(self, out: dict) -> None:
|
||||
"""
|
||||
Save the current options to a dictionary.
|
||||
"""
|
||||
for key, value in self._options.items():
|
||||
out[key] = value
|
||||
|
||||
def restore(self, inp: dict) -> None:
|
||||
"""
|
||||
Restore the options from a dictionary.
|
||||
"""
|
||||
for key, value in inp.items():
|
||||
self._options[key] = value
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +27,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
import ast
|
||||
import textwrap
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Helpers for common activation functions
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +27,6 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass_api.fusion.backend.sm80_emitter import Sm80Emitter
|
||||
import cutlass_api.fusion.backend.sm80_nodes as sm80_nodes
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Base class for Epilogue Visitor Emitter
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Emitter for Sm100 Epilogue Visitor
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
|
||||
from cutlass_api.fusion.ir import AuxLoadImpl, AuxStoreImpl
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Emitter for Sm80 Epilogue Visitor
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
from cutlass_api.fusion.library import (
|
||||
DataTypeTag,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Emitter for Sm90 Epilogue Visitor
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
|
||||
from cutlass_api.fusion.library import (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Epilogue Visitor interface for compiling, and running visitor-based epilogue.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Collection of builtin functions used for host reference in EVT
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
from cutlass_api.fusion.frontend.python_ast import (
|
||||
PythonASTFrontend,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Base class for Python EVT Frontend
|
||||
@@ -122,6 +122,9 @@ class EVTFrontendBase:
|
||||
# Parse the input
|
||||
self.parse(*args, **kwargs)
|
||||
|
||||
if not self.dag_ir.has_node("D"):
|
||||
raise RuntimeError("Output node D is not found in the epilogue function.")
|
||||
|
||||
# Verify the DAG IR to ensure that "D" is the output node with out_degree = 0
|
||||
if self.cc >= 90:
|
||||
if self.dag_ir.out_degree("D") != 0:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Python AST frontend that parses input into DAG IR
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
from cutlass_api.fusion.ir.compute_nodes import ComputeNode, ComputeImpl
|
||||
from cutlass_api.fusion.ir.dag_ir import DAGIR
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
import ctypes
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Python registration for compute nodes in EVT
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
DAG IR used by Python EVT
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Layout algebras
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Layout manipulation nodes and implementations
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Load nodes and implementations
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Base & visitor classes of DAGIR Nodes
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Store node and implementations
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
High-level class for tensor
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Copies of enum types from cutlass_library that are used in the fusion frontend.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
from cutlass_api.fusion.passes.graph_drawer import EVTGraphDrawer
|
||||
from cutlass_api.fusion.passes.pass_argument_type import PassGetArgumentType
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Construct the epilogue visitor argument type
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Merge non-tree sub-graphs of the DAG IR into a single DAG. The fused DAG will be implemented
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Fix the element_output of producer of D.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Infer the underlying implement of each node.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Eliminate layout manipulation nodes
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Pass manager for DAG IR.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
No op elimination node
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Preprocess the reduction nodes.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Shape and type propagation pass
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Compute the shared memory size in bytes
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Utilities for passes
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
from .int_tuple import *
|
||||
from .layout import *
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Functions for manipulating IntTuples
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Definition of CuTe Layouts and functions to manipulate them
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
Methods for layout swizzling
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
from abc import ABC
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ from cutlass_api.arguments import EpilogueArguments, RuntimeArguments
|
||||
from cutlass_api.artifact import CompiledArtifact
|
||||
from cutlass_api.metadata import KernelMetadata
|
||||
from cutlass_api.status import Status
|
||||
from cutlass_api.utils import device_cc
|
||||
|
||||
|
||||
class Kernel(ABC):
|
||||
@@ -43,6 +44,16 @@ class Kernel(ABC):
|
||||
Base class for all kernels to be implemented in providers
|
||||
"""
|
||||
|
||||
def __init__(self, metadata: KernelMetadata):
|
||||
self._metadata = metadata
|
||||
|
||||
@property
|
||||
def metadata(self) -> KernelMetadata:
|
||||
"""
|
||||
The read-only metadata for the kernel.
|
||||
"""
|
||||
return self._metadata
|
||||
|
||||
@final
|
||||
def supports(self, args: RuntimeArguments) -> Status:
|
||||
"""
|
||||
|
||||
@@ -90,7 +90,6 @@ class Manifest:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
epilogue_args = None if args is None else args.epilogue
|
||||
kernels = [
|
||||
k
|
||||
|
||||
@@ -68,50 +68,74 @@ def _convert_stride(shape: tuple[int, ...], stride: tuple[int, ...]) -> tuple[in
|
||||
return new_stride
|
||||
|
||||
|
||||
def _get_max_pow2_alignment(
|
||||
shape: tuple[int, ...], stride: tuple[int, ...], dtype: cutlass.Numeric
|
||||
) -> int:
|
||||
def _is_tuple_aligned(tup: tuple[int], divisibility: int, contiguous_dim: int) -> bool:
|
||||
"""
|
||||
Get the maximum power of 2 alignment for a given data type
|
||||
Check if the all elements of the shape/stride tuple are divisible by a given
|
||||
divisibility, except along the contiguous dimension.
|
||||
"""
|
||||
return all(
|
||||
t % divisibility == 0 for dim, t in enumerate(tup) if dim != contiguous_dim
|
||||
)
|
||||
|
||||
|
||||
def _get_max_pow2_divisibility(shape: tuple[int, ...], stride: tuple[int, ...]) -> int:
|
||||
"""
|
||||
Get the maximum power of 2 divisibility met by the given shape and stride.
|
||||
This is the largest power of 2 that divides the number of elements in the major mode (with stride 1).
|
||||
|
||||
:param shape: the shape of the tensor
|
||||
:type shape: tuple[int, ...]
|
||||
:param stride: the stride of the tensor
|
||||
:type stride: tuple[int, ...]
|
||||
:param dtype: the data type
|
||||
:type dtype: cutlass.Numeric
|
||||
|
||||
:return: the maximum power of 2 alignment
|
||||
:return: the maximum power of 2 divisibility by which the shape is divisible
|
||||
:rtype: int
|
||||
"""
|
||||
if 1 not in stride:
|
||||
return 1
|
||||
|
||||
major_mode_idx = stride.index(1)
|
||||
num_major_elements = shape[major_mode_idx]
|
||||
for alignment in [128, 64, 32, 16, 8, 4, 2]:
|
||||
num_contiguous_elements = alignment * 8 // dtype.width
|
||||
if num_major_elements % num_contiguous_elements == 0:
|
||||
return alignment
|
||||
return 1
|
||||
|
||||
best_divisibility = 1
|
||||
while num_major_elements % (best_divisibility * 2) == 0:
|
||||
best_divisibility *= 2
|
||||
return best_divisibility
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorAttributes:
|
||||
"""
|
||||
Description of a single tensor. This includes the data type, stride, and alignment.
|
||||
Description of a single tensor. This includes the data type, stride, and divisibility.
|
||||
|
||||
:param dtype: The data type of the tensor.
|
||||
:type dtype: cutlass.Numeric
|
||||
:param stride: The stride of the tensor.
|
||||
:type stride: tuple[int, ...]
|
||||
:param alignment: The alignment of the tensor
|
||||
:type alignment: int
|
||||
:param divisibility: The divisibility requirement on a tensor's stride & shape elements
|
||||
:type divisibility: int
|
||||
:param ptr_alignment_bytes: The alignment of the tensor's data pointer, in bytes.
|
||||
By default, it matches the number of bytes in stride/shape alignment.
|
||||
:type ptr_alignment_bytes: int
|
||||
"""
|
||||
|
||||
dtype: cutlass.Numeric # F32, F16, etc.
|
||||
stride: tuple[int, ...]
|
||||
alignment: int
|
||||
divisibility: int
|
||||
ptr_alignment_bytes: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dtype: cutlass.Numeric,
|
||||
stride: tuple[int, ...],
|
||||
divisibility: int,
|
||||
ptr_alignment_bytes: int | None = None,
|
||||
):
|
||||
self.dtype = dtype
|
||||
self.stride = stride
|
||||
self.divisibility = divisibility
|
||||
self.ptr_alignment_bytes = ptr_alignment_bytes or (
|
||||
(divisibility * dtype.width) // 8
|
||||
)
|
||||
|
||||
def supports(self, operand: TensorWrapper | Self) -> Status:
|
||||
"""
|
||||
@@ -159,22 +183,37 @@ class TensorAttributes:
|
||||
# When setting stride from args, any modes of stride 1 and shape 1
|
||||
# are changed to have stride 0. Thus, there will only be one mode
|
||||
# with stride 1.
|
||||
if not all_zeros and normalized_operand_stride[expected_stride.index(1)] != 1:
|
||||
contiguous_dim = expected_stride.index(1)
|
||||
if not all_zeros and normalized_operand_stride[contiguous_dim] != 1:
|
||||
return Status.fail(
|
||||
f"Expected stride[{expected_stride.index(1)}] to be 1, got {normalized_operand_stride[expected_stride.index(1)]} (strides: {normalized_operand_stride})"
|
||||
f"Expected stride[{contiguous_dim}] to be 1, got "
|
||||
f"{normalized_operand_stride[contiguous_dim]} "
|
||||
f"(strides: {normalized_operand_stride})"
|
||||
)
|
||||
|
||||
# Alignment of operand should be divisible by this metadata's alignment
|
||||
# Check that divisibility constraints are met
|
||||
if isinstance(operand, TensorWrapper):
|
||||
operand_alignment = _get_max_pow2_alignment(
|
||||
operand.shape, normalized_operand_stride, operand.element_type
|
||||
)
|
||||
if not _is_tuple_aligned(
|
||||
normalized_operand_stride, self.divisibility, contiguous_dim
|
||||
):
|
||||
return Status.fail(
|
||||
f"Expected operand stride to be divisible by {self.divisibility} for"
|
||||
f"all non-contiguous modes, got {normalized_operand_stride}"
|
||||
)
|
||||
else:
|
||||
operand_alignment = operand.alignment
|
||||
# When comparing another TensorAttribute, ensure its divisibility is a subset
|
||||
if operand.divisibility % self.divisibility != 0:
|
||||
return Status.fail(
|
||||
f"Expected operand divisibility {operand.divisibility} to be divisible by {self.divisibility}"
|
||||
)
|
||||
|
||||
if operand_alignment % self.alignment != 0:
|
||||
# Check data ptr alignment, if available
|
||||
if (
|
||||
isinstance(operand, TensorWrapper)
|
||||
and operand.data_ptr % self.ptr_alignment_bytes != 0
|
||||
):
|
||||
return Status.fail(
|
||||
f"Expected operand alignment {operand_alignment} (strides: {normalized_operand_stride}) to be a multiple of {self.alignment}"
|
||||
f"Expected data pointer to be {self.ptr_alignment_bytes}B-aligned."
|
||||
)
|
||||
|
||||
return Status.success()
|
||||
@@ -191,11 +230,9 @@ class TensorAttributes:
|
||||
:rtype: TensorAttributes
|
||||
"""
|
||||
stride = _convert_stride(tensor.shape, tensor.stride)
|
||||
max_alignment = _get_max_pow2_alignment(
|
||||
tensor.shape, stride, tensor.element_type
|
||||
)
|
||||
max_divisibility = _get_max_pow2_divisibility(tensor.shape, stride)
|
||||
return TensorAttributes(
|
||||
dtype=tensor.element_type, stride=stride, alignment=max_alignment
|
||||
dtype=tensor.element_type, stride=stride, divisibility=max_divisibility
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -26,9 +26,8 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
from typing import Type
|
||||
from collections.abc import Callable
|
||||
|
||||
from cutlass_api.arguments import EpilogueArguments
|
||||
from cutlass_api.kernel import Kernel
|
||||
@@ -64,6 +63,7 @@ if available:
|
||||
epilogue_args: EpilogueArguments = None,
|
||||
cc: int = None,
|
||||
) -> list[Kernel]:
|
||||
|
||||
kernels_for_provider = []
|
||||
for kernel_cls in cls._kernel_classes:
|
||||
kernels_for_provider.extend(
|
||||
|
||||
@@ -32,6 +32,9 @@ from typing import ClassVar
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
|
||||
from cutlass_api.arguments import ElementwiseArguments, EpilogueArguments
|
||||
from cutlass_api.artifact import CompiledArtifact
|
||||
from cutlass_api.metadata import (
|
||||
@@ -43,9 +46,6 @@ from cutlass_api.providers.cutedsl import CuTeDSLProvider
|
||||
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
|
||||
from cutlass_api.utils import to_cuda_stream
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
|
||||
"""
|
||||
An Elementwise Addition kernel using CuTe DSL:
|
||||
out = A + B
|
||||
@@ -60,7 +60,7 @@ class ElementwiseAddKernel(CuteDslKernel):
|
||||
]
|
||||
|
||||
def __init__(self, metadata: KernelMetadata):
|
||||
self.metadata = metadata
|
||||
super().__init__(metadata)
|
||||
self.impl = ElementwiseAddKernelImpl()
|
||||
|
||||
def compile(self, args: ElementwiseArguments, cc: int = None) -> CompiledArtifact:
|
||||
@@ -99,7 +99,7 @@ class ElementwiseAddKernel(CuteDslKernel):
|
||||
|
||||
kernel_list = []
|
||||
for dtype in ElementwiseAddKernel._supported_dtypes:
|
||||
alignment = 128 // dtype.width
|
||||
divisibility = 128 // dtype.width
|
||||
for stride_A, stride_B, stride_out in product(
|
||||
stride_names.keys(), repeat=3
|
||||
):
|
||||
@@ -112,17 +112,20 @@ class ElementwiseAddKernel(CuteDslKernel):
|
||||
|
||||
operands = ElementwiseOperandsMetadata(
|
||||
A=TensorAttributes(
|
||||
dtype=dtype, stride=stride_A, alignment=alignment
|
||||
dtype=dtype, stride=stride_A, divisibility=divisibility
|
||||
),
|
||||
B=TensorAttributes(
|
||||
dtype=dtype, stride=stride_B, alignment=alignment
|
||||
dtype=dtype, stride=stride_B, divisibility=divisibility
|
||||
),
|
||||
out=TensorAttributes(
|
||||
dtype=dtype, stride=stride_out, alignment=alignment
|
||||
dtype=dtype, stride=stride_out, divisibility=divisibility
|
||||
),
|
||||
)
|
||||
metadata = KernelMetadata(
|
||||
operands=operands, kernel_name=kernel_name, kernel_class=ElementwiseAddKernel, min_cc=min_cc
|
||||
operands=operands,
|
||||
kernel_name=kernel_name,
|
||||
kernel_class=ElementwiseAddKernel,
|
||||
min_cc=min_cc,
|
||||
)
|
||||
if metadata_filter(metadata):
|
||||
kernel_list.append(ElementwiseAddKernel(metadata))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#################################################################################################
|
||||
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
@@ -28,7 +28,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
@@ -31,3 +31,4 @@
|
||||
# ruff: noqa: F401
|
||||
import cutlass_api.providers.cutedsl.gemm.sm100_static_persistent
|
||||
import cutlass_api.providers.cutedsl.gemm.sm100_static_persistent_efc
|
||||
import cutlass_api.providers.cutedsl.gemm.sm80_tensorop_gemm
|
||||
|
||||
@@ -27,18 +27,15 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Union
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.cute.testing as testing
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
|
||||
import cutlass.utils as utils
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
|
||||
|
||||
"""
|
||||
A high-performance persistent batched dense GEMM example for the NVIDIA Blackwell SM100 architecture
|
||||
@@ -1222,7 +1219,7 @@ class PersistentDenseGemmKernelImpl:
|
||||
tAcc: cute.Tensor,
|
||||
gC_mnl: cute.Tensor,
|
||||
epi_tile: cute.Tile,
|
||||
use_2cta_instrs: Union[cutlass.Boolean, bool],
|
||||
use_2cta_instrs: cutlass.Boolean | bool,
|
||||
) -> tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]:
|
||||
"""
|
||||
Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination).
|
||||
|
||||
@@ -0,0 +1,814 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import argparse
|
||||
import math
|
||||
from typing import Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.cute.testing as testing
|
||||
import cutlass.torch as cutlass_torch
|
||||
import cutlass.utils as utils
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
"""
|
||||
A high-performance batched dense GEMM example utilizing SM80-style TensorCore instructions
|
||||
- Matrix A is LxMxK, L is batch dimension, A can be row-major("t") or column-major("n")
|
||||
- Matrix B is LxKxN, L is batch dimension, B can be row-major("t") or column-major("n")
|
||||
- Matrix C is LxMxN, L is batch dimension, C can be row-major("t") or column-major("n")
|
||||
- Internally, this is treated as an MxKxL, NxKxL, MxNxL CuTe Tensors
|
||||
|
||||
This GEMM kernel supports the following features:
|
||||
- Utilizes SM80's tensor cores for matrix multiply-accumulate (MMA) operations
|
||||
- Threadblock rasterization to improve data re-use
|
||||
- Supports multi-stage pipeline to overlap computation and memory access
|
||||
|
||||
Constraints:
|
||||
* Supported input and output data types: fp16, bf16
|
||||
* Support accumulator data types: f32
|
||||
* Atom layout's MNK shape is set so that tile shape can be divided by MMA instruction shape
|
||||
* The contiguous dimension of A/B/C tensors must be at least 16-byte aligned.
|
||||
"""
|
||||
|
||||
class Sm80TensorOpGemmImpl:
|
||||
def __init__(
|
||||
self,
|
||||
a_dtype: Type[cutlass.Numeric],
|
||||
b_dtype: Type[cutlass.Numeric],
|
||||
c_dtype: Type[cutlass.Numeric],
|
||||
acc_dtype: Type[cutlass.Numeric],
|
||||
):
|
||||
self.ab_dtype = a_dtype
|
||||
self.c_dtype = c_dtype
|
||||
self.acc_dtype = acc_dtype
|
||||
self.cta_tiler = (128, 128, 32)
|
||||
self.num_stages = 3
|
||||
self.atom_layout_mnk = (2, 2, 1)
|
||||
atom_lay_M, atom_lay_N, atom_lay_K = self.atom_layout_mnk
|
||||
self.num_threads = atom_lay_M * atom_lay_N * atom_lay_K * 32
|
||||
|
||||
self.bM, self.bN, self.bK = self.cta_tiler
|
||||
self.mma_inst_shape = (16, 8, 16)
|
||||
mmaM, mmaN, mmaK = self.mma_inst_shape
|
||||
|
||||
assert a_dtype == b_dtype, "This example does not support different A, B dtypes"
|
||||
assert self.bM % (atom_lay_M * mmaM) == 0, (
|
||||
"bM must be divisible by MMA instruction"
|
||||
)
|
||||
assert self.bN % (atom_lay_N * mmaN) == 0, (
|
||||
"bN must be divisible by MMA instruction"
|
||||
)
|
||||
assert atom_lay_K == 1, "this example does not support atom layout K > 1"
|
||||
assert self.bK % mmaK == 0, "bK must be divisible by MMA instruction"
|
||||
assert self.num_stages >= 3, "num_stages must be greater than or equal to 3"
|
||||
|
||||
    @cute.jit
    def __call__(
        self,
        a: cute.Tensor,
        b: cute.Tensor,
        c: cute.Tensor,
        stream,
        epilogue_op: cutlass.Constexpr = lambda x: x,
    ):
        """
        Trace and launch the SM80 TensorOp GEMM: C = epilogue_op(A @ B).

        Tensors arrive in torch convention — A: (L, M, K), B: (L, K, N),
        C: (L, M, N) — or rank-2 without the batch mode L, in which case a
        unit batch mode is prepended.

        :param a: operand A
        :param b: operand B
        :param c: output tensor C (written in place)
        :param stream: NOTE(review): accepted but not forwarded to
            ``.launch`` below — confirm intended launch-stream semantics
        :param epilogue_op: elementwise op applied to the accumulator before
            the store to C (identity by default)
        """

        # Prepend a unit batch mode so rank-2 inputs become rank-3 (no-op on
        # already-rank-3 layouts beyond the prepend up to rank 3).
        def add_batch_mode(tensor: cute.Tensor) -> cute.Tensor:
            return cute.make_tensor(
                tensor.iterator,
                cute.prepend(tensor.layout, cute.make_layout(1), up_to_rank=3),
            )
        a = add_batch_mode(a)
        b = add_batch_mode(b)
        c = add_batch_mode(c)

        # Permute tensor modes from torch to cute convention
        # A: (L, M, K) -> (M, K, L)
        a = cute.make_tensor(a.iterator, cute.select(a.layout, [1, 2, 0]))
        # B: (L, K, N) -> (N, K, L)
        b = cute.make_tensor(b.iterator, cute.select(b.layout, [2, 1, 0]))
        # C: (L, M, N) -> (M, N, L)
        c = cute.make_tensor(c.iterator, cute.select(c.layout, [1, 2, 0]))

        # The grid divides the problem's M, N, and L dimensions by the
        # respective modes of the tile shape (bM, bN, 1). The K dimension is
        # handled within a block via a multistage process.

        # Record the majorness (row/column) of each operand; the smem layout
        # and tiled-copy builders below are specialized on it.
        self.a_major_mode = utils.LayoutEnum.from_tensor(a)
        self.b_major_mode = utils.LayoutEnum.from_tensor(b)
        self.c_major_mode = utils.LayoutEnum.from_tensor(c)

        # ///////////////////////////////////////////////////////////////////////////////
        # Shared memory layout:
        # ///////////////////////////////////////////////////////////////////////////////

        # Creates a layout with the size required for the provided tile
        # size and num stages (stages are used for K dimension) that is also
        # sectioned into 64x8 or 8x32 layout atoms. The swizzle is set so that
        # the atom for the shared memory -> register copy does not encounter
        # bank conflicts

        # assume the input is 16B aligned (128 bits per copy)
        ab_copy_bits = 128
        sA_layout = self._make_smem_layout_AB(
            a.element_type,
            self.a_major_mode,
            ab_copy_bits,
            (self.cta_tiler[0], self.cta_tiler[2], self.num_stages),
        )
        sB_layout = self._make_smem_layout_AB(
            b.element_type,
            self.b_major_mode,
            ab_copy_bits,
            (self.cta_tiler[1], self.cta_tiler[2], self.num_stages),
        )

        # Creates a similar layout but without num_stages or layout atoms
        sC_layout = self._make_smem_layout_C(
            c.element_type,
            self.c_major_mode,
            ab_copy_bits,
            (self.cta_tiler[0], self.cta_tiler[1]),
        )

        # Shared memory allocated for operations with A, B will be
        # overwritten for operations on C. This is to improve performance
        # by reducing the size of shared memory requested by each block
        smem_size = max(
            cute.size_in_bytes(c.element_type, sC_layout),
            cute.size_in_bytes(a.element_type, sA_layout)
            + cute.size_in_bytes(b.element_type, sB_layout),
        )

        # ///////////////////////////////////////////////////////////////////////////////
        # Tiled copy:
        # The majorness of tA/tB/tC follows the majorness of gA/gB/gC,
        # enabling merged accesses to global memory for faster data
        # transfer between global and shared memory.
        # ///////////////////////////////////////////////////////////////////////////////

        # Create a copy atom for a global to shared memory asynchronous copy
        atom_async_copy = cute.make_copy_atom(
            cute.nvgpu.cpasync.CopyG2SOp(
                cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL
            ),
            a.element_type,
            num_bits_per_copy=ab_copy_bits,
        )

        # Create thread layouts for tiled copy from the copy atom where the
        # thread layout simply follows the leading dimension of the tensor
        tiled_copy_A = self._make_gmem_tiled_copy_AB(
            atom_async_copy, a.element_type, self.a_major_mode, ab_copy_bits
        )
        tiled_copy_B = self._make_gmem_tiled_copy_AB(
            atom_async_copy, b.element_type, self.b_major_mode, ab_copy_bits
        )

        # Creates a synchronous copy atom and thread layouts for the epilogue
        c_copy_bits = 128
        atom_sync_copy = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            c.element_type,
            num_bits_per_copy=c_copy_bits,
        )
        tiled_copy_C = self._make_gmem_tiled_copy_C(
            atom_sync_copy, c.element_type, self.c_major_mode, c_copy_bits
        )

        # ///////////////////////////////////////////////////////////////////////////////
        # Tiled MMA
        # ///////////////////////////////////////////////////////////////////////////////

        # Creates a mma atom with 16x8x16 shape for MNK
        op = cute.nvgpu.warp.MmaF16BF16Op(
            self.ab_dtype, self.acc_dtype, self.mma_inst_shape
        )

        permutation_mnk = (
            self.atom_layout_mnk[0] * self.mma_inst_shape[0],
            # if atom layout's N-mode is 1, to leverage the largest coalesced
            # shared memory -> register copy, set the tiled mma's N mode to 16
            self.atom_layout_mnk[1] * self.mma_inst_shape[1] * 2,
            self.atom_layout_mnk[2] * self.mma_inst_shape[2],
        )

        # Created a tiled mma that tiles the atom according to specified layout.
        # For a 2x2x1 atom layout, the mma atom is duplicated 4 times, twice
        # across M and twice across N
        tC = cute.make_layout(self.atom_layout_mnk)
        tiled_mma = cute.make_tiled_mma(
            op,
            tC,
            permutation_mnk=permutation_mnk,
        )

        # grid_dim: ((m + BLK_M - 1) // BLK_M, (n + BLK_N - 1) // BLK_N, l)
        grid_dim = cute.ceil_div(c.shape, (self.bM, self.bN, 1))

        # Add threadblock rasterization to improve re-use of data
        raster_factor = 1
        grid_dim_n = cute.size(grid_dim[1])
        # Thresholds picked so that it doesn't cause too many no-op CTAs
        if grid_dim_n > 5:
            raster_factor = 8
        elif grid_dim_n > 2:
            raster_factor = 4
        elif grid_dim_n > 1:
            raster_factor = 2
        # Remapped grid: X is stretched by the raster factor, Y shrunk by it
        # (rounded up); out-of-range CTAs early-exit inside the kernel.
        rasterization_remap_grid_dim = (
            cute.size(grid_dim[0]) * raster_factor,
            (cute.size(grid_dim[1]) + raster_factor - 1) // raster_factor,
            cute.size(grid_dim[2]),
        )

        self.kernel(
            a,
            b,
            c,
            sA_layout,
            sB_layout,
            sC_layout,
            tiled_copy_A,
            tiled_copy_B,
            tiled_copy_C,
            tiled_mma,
            raster_factor,
            epilogue_op,
        ).launch(
            grid=rasterization_remap_grid_dim,
            block=[self.num_threads, 1, 1],
            smem=smem_size,
        )
|
||||
|
||||
@cute.kernel
|
||||
def kernel(
|
||||
self,
|
||||
mA: cute.Tensor,
|
||||
mB: cute.Tensor,
|
||||
mC: cute.Tensor,
|
||||
sA_layout: cute.ComposedLayout,
|
||||
sB_layout: cute.ComposedLayout,
|
||||
sC_layout: cute.ComposedLayout,
|
||||
tiled_copy_A: cute.TiledCopy,
|
||||
tiled_copy_B: cute.TiledCopy,
|
||||
tiled_copy_C: cute.TiledCopy,
|
||||
tiled_mma: cute.TiledMma,
|
||||
rasterization_factor: cutlass.Int32,
|
||||
epilogue_op: cutlass.Constexpr = lambda x: x,
|
||||
):
|
||||
# Thread index, block index
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, bidy, bidz = cute.arch.block_idx()
|
||||
grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))
|
||||
offset_tile_x, offset_tile_y = self.raster_tile(
|
||||
bidx, bidy, rasterization_factor
|
||||
)
|
||||
# Early exit if CTA is out of range
|
||||
if grid_dim[0] <= offset_tile_x or grid_dim[1] <= offset_tile_y:
|
||||
pass
|
||||
else:
|
||||
tiler_coord = (offset_tile_x, offset_tile_y, None)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Get the appropriate tiles for this thread block.
|
||||
# gA: (BLK_M, BLK_N, k), gB: (BLK_N, BLK_K, k), gC: (BLK_M, BLK_N)
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
gA = cute.local_tile(
|
||||
mA[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(1, None, 1),
|
||||
)
|
||||
gB = cute.local_tile(
|
||||
mB[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(None, 1, 1),
|
||||
)
|
||||
gC = cute.local_tile(
|
||||
mC[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(1, 1, None),
|
||||
)
|
||||
|
||||
# By default, if the tensor k mode does not divide into the tile k
|
||||
# size, then last tiles in the k dimension are irregular.
|
||||
# Instead, make the first tiles irregular when k is irregular.
|
||||
# This allows us to handle the irregular tile first to avoid
|
||||
# checking for this condition within the mainloop.
|
||||
|
||||
# residual_k is a negative number indicating the amount needed to
|
||||
# shift the pointer by in dimension k
|
||||
residual_k = cute.size(mA, mode=[1]) - cutlass.Int32(self.bK) * cute.size(
|
||||
gA, mode=[2]
|
||||
)
|
||||
|
||||
# move the pointer of gA/gB in the `-k` direction
|
||||
gA = cute.domain_offset((0, residual_k, 0), gA)
|
||||
gB = cute.domain_offset((0, residual_k, 0), gB)
|
||||
# input is 16B aligned
|
||||
gA = cute.make_tensor(gA.iterator.align(16), gA.layout)
|
||||
gB = cute.make_tensor(gB.iterator.align(16), gB.layout)
|
||||
|
||||
# Construct identity layout for sA and sB (mirrors global tensors,
|
||||
# used for predication only)
|
||||
mcA = cute.make_identity_tensor(mA.layout.shape)
|
||||
mcB = cute.make_identity_tensor(mB.layout.shape)
|
||||
cA = cute.local_tile(
|
||||
mcA[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(1, None, 1),
|
||||
)
|
||||
cB = cute.local_tile(
|
||||
mcB[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(None, 1, 1),
|
||||
)
|
||||
|
||||
cA = cute.domain_offset((0, residual_k, 0), cA)
|
||||
cB = cute.domain_offset((0, residual_k, 0), cB)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Create shared memory buffers and get the appropriate fragments for this thread.
|
||||
# sA: (BLK_M, BLK_K, PIPE) , sB: (BLK_N, BLK_K, PIPE)
|
||||
# tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k)
|
||||
# tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE)
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Shared memory buffer
|
||||
smem = cutlass.utils.SmemAllocator()
|
||||
|
||||
sA = smem.allocate_tensor(mA.element_type, sA_layout, 16)
|
||||
sB = smem.allocate_tensor(mB.element_type, sB_layout, 16)
|
||||
sC = cute.make_tensor(
|
||||
cute.recast_ptr(sA.iterator, dtype=self.c_dtype), sC_layout
|
||||
)
|
||||
|
||||
thr_copy_A = tiled_copy_A.get_slice(tidx)
|
||||
thr_copy_B = tiled_copy_B.get_slice(tidx)
|
||||
thr_copy_C = tiled_copy_C.get_slice(tidx)
|
||||
tAgA = thr_copy_A.partition_S(gA)
|
||||
tAsA = thr_copy_A.partition_D(sA)
|
||||
tBgB = thr_copy_B.partition_S(gB)
|
||||
tBsB = thr_copy_B.partition_D(sB)
|
||||
tCsC_epilogue = thr_copy_C.partition_S(sC)
|
||||
tCgC_epilogue = thr_copy_C.partition_D(gC)
|
||||
|
||||
# Repeat the partitioning with identity layouts
|
||||
tAcA = thr_copy_A.partition_S(cA)
|
||||
tBcB = thr_copy_B.partition_S(cB)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Predicate: Mark indices that need to copy when problem_shape isn't a multiple
|
||||
# of tile_shape
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
# For predication over the tensors A (M/K), B (N/K), and (in the
|
||||
# epilogue) C (M/N), we will compute it in a fashion similar to an
|
||||
# outer product. The predication along one of the dimensions is
|
||||
# evaluated and stored in a predication tensor. Then, the
|
||||
# predication for the remaining dimension is handled later via an
|
||||
# if/else branch at the copy.
|
||||
# For A and B, predication booleans along M/N are stored in a
|
||||
# predication tensor and along K is handled via a if/else branch.
|
||||
|
||||
# Allocate predicate tensors for M and N. Predication is checked
|
||||
# at the granularity of a copy atom, so the predicate tensor does not
|
||||
# need separate booleans for individual elements within a copy
|
||||
# atom (for example, the elements of tAgA.shape[0][0].)
|
||||
tApA = cute.make_rmem_tensor(
|
||||
cute.make_layout(
|
||||
(
|
||||
tAgA.shape[0][1],
|
||||
cute.size(tAgA, mode=[1]),
|
||||
cute.size(tAgA, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tAgA, mode=[1]), 1, 0),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
tBpB = cute.make_rmem_tensor(
|
||||
cute.make_layout(
|
||||
(
|
||||
tBsB.shape[0][1],
|
||||
cute.size(tBsB, mode=[1]),
|
||||
cute.size(tBsB, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tBsB, mode=[1]), 1, 0),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
# Set predicates for M/N bounds
|
||||
for rest_v in range(tApA.shape[0]):
|
||||
for m in range(tApA.shape[1]):
|
||||
tApA[rest_v, m, 0] = cute.elem_less(
|
||||
tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0]
|
||||
)
|
||||
for rest_v in range(tBpB.shape[0]):
|
||||
for n in range(tBpB.shape[1]):
|
||||
tBpB[rest_v, n, 0] = cute.elem_less(
|
||||
tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0]
|
||||
)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Prefetch Prologue
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Clear the smem tiles to account for predicated off loads
|
||||
tAsA.fill(0)
|
||||
tBsB.fill(0)
|
||||
cute.arch.sync_threads()
|
||||
# Start async loads for the first k-tile. Here we take care of the k residue
|
||||
# via if/else check along the k dimension. Because we shifted the identity tensor
|
||||
# by the residue_k and because the identity tensor is a coord tensor, the
|
||||
# values of any identity tensor element that is poison is less than -1
|
||||
num_smem_stages = cute.size(tAsA, mode=[3])
|
||||
k_tile_count = cute.size(tAgA, mode=[3])
|
||||
k_tile_index = cutlass.Int32(0)
|
||||
|
||||
for k in range(tApA.shape[2]):
|
||||
if cute.elem_less(cutlass.Int32(-1), tAcA[0, 0, k, 0][1]):
|
||||
cute.copy(
|
||||
tiled_copy_A,
|
||||
tAgA[None, None, k, k_tile_index],
|
||||
tAsA[None, None, k, 0],
|
||||
pred=tApA[None, None, k],
|
||||
)
|
||||
for k in range(tBpB.shape[2]):
|
||||
if cute.elem_less(cutlass.Int32(-1), tBcB[0, 0, k, 0][1]):
|
||||
cute.copy(
|
||||
tiled_copy_B,
|
||||
tBgB[None, None, k, k_tile_index],
|
||||
tBsB[None, None, k, 0],
|
||||
pred=tBpB[None, None, k],
|
||||
)
|
||||
k_tile_index = k_tile_index + 1
|
||||
cute.arch.cp_async_commit_group()
|
||||
|
||||
# Start async loads for rest of the k-tiles
|
||||
for k_tile in range(1, num_smem_stages - 1):
|
||||
if k_tile == k_tile_count:
|
||||
tApA.fill(0)
|
||||
tBpB.fill(0)
|
||||
cute.copy(
|
||||
tiled_copy_A,
|
||||
tAgA[None, None, None, k_tile_index],
|
||||
tAsA[None, None, None, k_tile],
|
||||
pred=tApA,
|
||||
)
|
||||
cute.copy(
|
||||
tiled_copy_B,
|
||||
tBgB[None, None, None, k_tile_index],
|
||||
tBsB[None, None, None, k_tile],
|
||||
pred=tBpB,
|
||||
)
|
||||
k_tile_index = k_tile_index + 1
|
||||
cute.arch.cp_async_commit_group()
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Tile MMA compute thread partitions and allocate accumulators
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
thr_mma = tiled_mma.get_slice(tidx)
|
||||
tCsA = thr_mma.partition_A(sA)
|
||||
tCsB = thr_mma.partition_B(sB)
|
||||
tCsC = thr_mma.partition_C(sC)
|
||||
tCgC = thr_mma.partition_C(gC)
|
||||
tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0])
|
||||
tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0])
|
||||
tCrC = tiled_mma.make_fragment_C(tCgC)
|
||||
# Clear the accumulator
|
||||
tCrC.fill(0.0)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Copy Atom A/B retiling
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
# Create the copy atoms for the copy from shared memory to register
|
||||
atom_copy_s2r_A = cute.make_copy_atom(
|
||||
cute.nvgpu.warp.LdMatrix8x8x16bOp(
|
||||
self.a_major_mode != utils.LayoutEnum.ROW_MAJOR, 4
|
||||
),
|
||||
mA.element_type,
|
||||
)
|
||||
atom_copy_s2r_B = cute.make_copy_atom(
|
||||
cute.nvgpu.warp.LdMatrix8x8x16bOp(
|
||||
self.b_major_mode != utils.LayoutEnum.ROW_MAJOR, 4
|
||||
),
|
||||
mB.element_type,
|
||||
)
|
||||
|
||||
# Creates the tiled copy so that it matches the thread-value layout
|
||||
# expected by the tiled mma
|
||||
tiled_copy_s2r_A = cute.make_tiled_copy_A(atom_copy_s2r_A, tiled_mma)
|
||||
tiled_copy_s2r_B = cute.make_tiled_copy_B(atom_copy_s2r_B, tiled_mma)
|
||||
|
||||
thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx)
|
||||
thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx)
|
||||
tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA)
|
||||
tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA)
|
||||
tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB)
|
||||
tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB)
|
||||
|
||||
# Current pipe index in smem to read from / write to
|
||||
smem_pipe_read = 0
|
||||
smem_pipe_write = num_smem_stages - 1
|
||||
|
||||
tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
|
||||
tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# PREFETCH register pipeline
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
num_k_block = cute.size(tCrA, mode=[2])
|
||||
if num_k_block > 1:
|
||||
# Wait until our first prefetched tile is loaded in
|
||||
cute.arch.cp_async_wait_group(num_smem_stages - 2)
|
||||
cute.arch.sync_threads()
|
||||
# Prefetch the first k-block rmem from the first k-tile
|
||||
cute.copy(
|
||||
tiled_copy_s2r_A,
|
||||
tCsA_p[None, None, 0],
|
||||
tCrA_copy_view[None, None, 0],
|
||||
)
|
||||
cute.copy(
|
||||
tiled_copy_s2r_B,
|
||||
tCsB_p[None, None, 0],
|
||||
tCrB_copy_view[None, None, 0],
|
||||
)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Mainloop
|
||||
# 1. Shared memory pipeline (gmem -> smem):
|
||||
# The default smem pipeline depth is 3, meaning that for shared
|
||||
# memory buffers, we allocate three times the size described by the
|
||||
# CTA tiler. We prefetch 2 of these buffers before entering the main
|
||||
# loop. Considering only the transfer from global memory to shared
|
||||
# memory, the general structure of the mainloop is:
|
||||
# (1) copy k-tile from gmem to smem;
|
||||
# (2) perform gemm computation on k-tile;
|
||||
# (3) wait for the next copy to finish.
|
||||
# The `cute.arch.cp_async_wait_group(num_smem_stages - 2)` command
|
||||
# waits for the number of unfinished 'copy' to be <= 1. The advantage
|
||||
# of this approach is that it allows for simultaneous production
|
||||
# (i.e., step (1)) and consumption (i.e., step (2)) of smem.
|
||||
# A common misconception is to prefetch N buffers and rewrite
|
||||
# the pipeline logic to wait on N-1 pending copies. The disadvantage
|
||||
# of this approach is that it requires fully consuming a buffer in
|
||||
# order to open an empty buffer for the next copy.
|
||||
# 2. Register pipeline (smem -> register):
|
||||
# Similarly, the register pipeline produces i+1, consumes i, and
|
||||
# produces i+2... Notably, i and i+1 do not use the same register,
|
||||
# eliminating dependencies on the same register for better parallelism.
|
||||
# 3. Combining the smem and register pipelines results in the mainloop.
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
for k_tile in range(k_tile_count):
|
||||
for k_block in cutlass.range(num_k_block, unroll_full=True):
|
||||
if k_block == num_k_block - 1:
|
||||
tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
|
||||
tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]
|
||||
cute.arch.cp_async_wait_group(num_smem_stages - 2)
|
||||
cute.arch.sync_threads()
|
||||
|
||||
# Load A, B from shared memory to registers for k_block + 1
|
||||
k_block_next = (k_block + 1) % num_k_block # static
|
||||
cute.copy(
|
||||
tiled_copy_s2r_A,
|
||||
tCsA_p[None, None, k_block_next],
|
||||
tCrA_copy_view[None, None, k_block_next],
|
||||
)
|
||||
cute.copy(
|
||||
tiled_copy_s2r_B,
|
||||
tCsB_p[None, None, k_block_next],
|
||||
tCrB_copy_view[None, None, k_block_next],
|
||||
)
|
||||
|
||||
# Fetch next A: To better interleave global memory access and compute
|
||||
# instructions, we intentionally use the sequence: copy A, perform GEMM,
|
||||
# then copy B.
|
||||
if k_block == 0:
|
||||
if k_tile + num_smem_stages - 1 < k_tile_count:
|
||||
cute.copy(
|
||||
tiled_copy_A,
|
||||
tAgA[None, None, None, k_tile_index],
|
||||
tAsA[None, None, None, smem_pipe_write],
|
||||
pred=tApA,
|
||||
)
|
||||
|
||||
# Thread-level register gemm for k_block
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCrC,
|
||||
tCrA[None, None, k_block],
|
||||
tCrB[None, None, k_block],
|
||||
tCrC,
|
||||
)
|
||||
|
||||
# Fetch next B and update smem pipeline read/write
|
||||
if k_block == 0:
|
||||
if k_tile + num_smem_stages - 1 < k_tile_count:
|
||||
cute.copy(
|
||||
tiled_copy_B,
|
||||
tBgB[None, None, None, k_tile_index],
|
||||
tBsB[None, None, None, smem_pipe_write],
|
||||
pred=tBpB,
|
||||
)
|
||||
k_tile_index = k_tile_index + 1
|
||||
cute.arch.cp_async_commit_group()
|
||||
smem_pipe_write = smem_pipe_read
|
||||
smem_pipe_read = smem_pipe_read + 1
|
||||
if smem_pipe_read == num_smem_stages:
|
||||
smem_pipe_read = 0
|
||||
|
||||
# Sync before epilogue
|
||||
cute.arch.cp_async_wait_group(0)
|
||||
cute.arch.sync_threads()
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Epilogue with fusion
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
tCrD = cute.make_fragment_like(tCrC, self.c_dtype)
|
||||
tCrD[None] = epilogue_op(tCrC.load()).to(self.c_dtype)
|
||||
|
||||
# Copy results of D back to shared memory
|
||||
cute.autovec_copy(tCrD, tCsC)
|
||||
|
||||
# Create coord tensor for C
|
||||
ceilM, ceilN, _ = cute.ceil_div(mC.shape, (self.bM, self.bN, 1))
|
||||
mcC = cute.make_identity_tensor(
|
||||
(
|
||||
cute.size(ceilM) * self.cta_tiler[0],
|
||||
cute.size(ceilN) * self.cta_tiler[1],
|
||||
1,
|
||||
)
|
||||
)
|
||||
cC = cute.local_tile(
|
||||
mcC[None, None, bidz],
|
||||
tiler=self.cta_tiler,
|
||||
coord=tiler_coord,
|
||||
proj=(1, 1, None),
|
||||
)
|
||||
tCcC = thr_copy_C.partition_S(cC)
|
||||
|
||||
tCrC_epilogue = cute.make_fragment_like(tCsC_epilogue)
|
||||
# Wait for all writes to shared memory to finish before starting copies
|
||||
# using the new layouts
|
||||
cute.arch.sync_threads()
|
||||
cute.autovec_copy(tCsC_epilogue, tCrC_epilogue)
|
||||
|
||||
# Create predication tensor for m
|
||||
tCpC = cute.make_rmem_tensor(
|
||||
cute.make_layout(
|
||||
(
|
||||
tCgC_epilogue.shape[0][1],
|
||||
cute.size(tCgC_epilogue, mode=[1]),
|
||||
cute.size(tCgC_epilogue, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tCgC_epilogue, mode=[1]), 1, 0),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
for rest_v in range(tCpC.shape[0]):
|
||||
for m in range(tCpC.shape[1]):
|
||||
tCpC[rest_v, m, 0] = cute.elem_less(
|
||||
tCcC[(0, rest_v), m, 0][0], mC.shape[0]
|
||||
)
|
||||
|
||||
# Copy to global memory using better vectorization
|
||||
for rest_v in range(tCpC.shape[0]):
|
||||
for n in range(tCpC.shape[2]):
|
||||
if cute.elem_less(tCcC[(0, rest_v), 0, n][1], mC.shape[1]):
|
||||
cute.copy(
|
||||
tiled_copy_C,
|
||||
tCrC_epilogue[None, None, n],
|
||||
tCgC_epilogue[None, None, n],
|
||||
pred=tCpC[None, None, n],
|
||||
)
|
||||
return
|
||||
|
||||
def _make_smem_layout_AB(self, dtype, major_mode, copy_bits, smem_tiler):
|
||||
major_mode_size = (
|
||||
smem_tiler[1] if major_mode == utils.LayoutEnum.ROW_MAJOR else smem_tiler[0]
|
||||
)
|
||||
major_mode_size = 64 if major_mode_size >= 64 else major_mode_size
|
||||
|
||||
swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits))
|
||||
swizzle_bits = min(swizzle_bits, 3)
|
||||
|
||||
layout_atom_outer = (
|
||||
cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1))
|
||||
if major_mode == utils.LayoutEnum.ROW_MAJOR
|
||||
else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size))
|
||||
)
|
||||
layout_atom = cute.make_composed_layout(
|
||||
cute.make_swizzle(swizzle_bits, 3, 3),
|
||||
0,
|
||||
layout_atom_outer,
|
||||
)
|
||||
layout = cute.tile_to_shape(layout_atom, smem_tiler, (0, 1, 2))
|
||||
return layout
|
||||
|
||||
def _make_smem_layout_C(self, dtype, major_mode, copy_bits, smem_tiler):
|
||||
major_mode_size = (
|
||||
smem_tiler[1] if major_mode == utils.LayoutEnum.ROW_MAJOR else smem_tiler[0]
|
||||
)
|
||||
|
||||
swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits))
|
||||
swizzle_bits = min(swizzle_bits, 3)
|
||||
|
||||
layout_atom_outer = (
|
||||
cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1))
|
||||
if major_mode == utils.LayoutEnum.ROW_MAJOR
|
||||
else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size))
|
||||
)
|
||||
layout_atom = cute.make_composed_layout(
|
||||
cute.make_swizzle(swizzle_bits, 3, 4),
|
||||
0,
|
||||
layout_atom_outer,
|
||||
)
|
||||
|
||||
# Due to the thread layout of the mma, remove swizzle in C to
|
||||
# prevent shared memory fragments owned by an single thread from
|
||||
# holding swizzles
|
||||
if major_mode == utils.LayoutEnum.COL_MAJOR:
|
||||
layout_atom = cute.make_composed_layout(
|
||||
cute.make_swizzle(0, 3, 4), 0, layout_atom_outer
|
||||
)
|
||||
layout = cute.tile_to_shape(
|
||||
layout_atom,
|
||||
smem_tiler,
|
||||
(0, 1),
|
||||
)
|
||||
return layout
|
||||
|
||||
def _make_gmem_tiled_copy_AB(self, atom_copy, dtype, major_mode, copy_bits):
|
||||
copy_elems = copy_bits // dtype.width
|
||||
shape_dim_1 = cute.size(self.bK) // copy_elems
|
||||
# thread layout for copy
|
||||
thread_layout = cute.make_layout(
|
||||
(self.num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
|
||||
)
|
||||
if major_mode != utils.LayoutEnum.ROW_MAJOR:
|
||||
shape_dim_0 = cute.size(self.bM) // copy_elems
|
||||
thread_layout = cute.make_layout(
|
||||
(shape_dim_0, self.num_threads // shape_dim_0), stride=(1, shape_dim_0)
|
||||
)
|
||||
# Value layout for copy
|
||||
value_layout = (
|
||||
cute.make_layout((1, copy_elems))
|
||||
if major_mode == utils.LayoutEnum.ROW_MAJOR
|
||||
else cute.make_layout((copy_elems, 1))
|
||||
)
|
||||
return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout)
|
||||
|
||||
def _make_gmem_tiled_copy_C(self, atom_copy, dtype, major_mode, copy_bits):
|
||||
copy_elems = copy_bits // dtype.width
|
||||
shape_dim_1 = cute.size(self.bN) // copy_elems
|
||||
# thread layout for copy
|
||||
thread_layout = cute.make_layout(
|
||||
(self.num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)
|
||||
)
|
||||
if major_mode != utils.LayoutEnum.ROW_MAJOR:
|
||||
shape_dim_0 = cute.size(self.bM) // copy_elems
|
||||
thread_layout = cute.make_layout(
|
||||
(shape_dim_0, self.num_threads // shape_dim_0), stride=(1, shape_dim_0)
|
||||
)
|
||||
value_layout = (
|
||||
cute.make_layout((1, copy_elems))
|
||||
if major_mode == utils.LayoutEnum.ROW_MAJOR
|
||||
else cute.make_layout((copy_elems, 1))
|
||||
)
|
||||
return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout)
|
||||
|
||||
def raster_tile(self, i, j, f):
|
||||
new_i = i // f
|
||||
new_j = (i % f) + (j * f)
|
||||
return (new_i, new_j)
|
||||
@@ -60,7 +60,6 @@ class PersistentDenseGemmKernel(CuteDslKernel):
|
||||
- i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported
|
||||
|
||||
:note: Supported A/B data types:
|
||||
- TFloat32
|
||||
- Float16/BFloat16
|
||||
- Int8/Uint8
|
||||
- Float8E4M3FN/Float8E5M2
|
||||
@@ -85,7 +84,7 @@ class PersistentDenseGemmKernel(CuteDslKernel):
|
||||
"""
|
||||
|
||||
def __init__(self, metadata: KernelMetadata):
|
||||
self.metadata = metadata
|
||||
super().__init__(metadata)
|
||||
|
||||
def epilogue_op(x):
|
||||
return x
|
||||
@@ -106,6 +105,7 @@ class PersistentDenseGemmKernel(CuteDslKernel):
|
||||
)
|
||||
|
||||
def _supports(self, args: GemmArguments) -> Status:
|
||||
|
||||
if args.epilogue is not None:
|
||||
return Status.fail("This kernel does not support any epilogue fusion.")
|
||||
return Status.success()
|
||||
@@ -149,12 +149,10 @@ class PersistentDenseGemmKernel(CuteDslKernel):
|
||||
abtype = operands.A.dtype
|
||||
|
||||
# Supported A/B data types:
|
||||
# - TFloat32
|
||||
# - Float16/BFloat16
|
||||
# - Int8/Uint8
|
||||
# - Float8E4M3FN/Float8E5M2
|
||||
if abtype not in [
|
||||
cutlass.Float32,
|
||||
cutlass.Float16,
|
||||
cutlass.BFloat16,
|
||||
cutlass.Int8,
|
||||
@@ -197,7 +195,11 @@ class PersistentDenseGemmKernel(CuteDslKernel):
|
||||
operands.out.dtype == cutlass.Float16
|
||||
or operands.out.dtype == cutlass.BFloat16
|
||||
):
|
||||
if operands.accumulator_type not in [cutlass.Float16, cutlass.BFloat16]:
|
||||
if operands.accumulator_type not in [
|
||||
cutlass.Float32,
|
||||
cutlass.Float16,
|
||||
cutlass.BFloat16,
|
||||
]:
|
||||
return False
|
||||
elif operands.out.dtype == cutlass.Int8 or operands.out.dtype == cutlass.Uint8:
|
||||
if operands.accumulator_type not in [cutlass.Int32]:
|
||||
@@ -221,7 +223,6 @@ class PersistentDenseGemmKernel(CuteDslKernel):
|
||||
"""
|
||||
# Supported A/B data types (must be the same)
|
||||
ab_dtypes = [
|
||||
cutlass.Float32,
|
||||
cutlass.Float16,
|
||||
cutlass.BFloat16,
|
||||
cutlass.Int8,
|
||||
@@ -232,15 +233,13 @@ class PersistentDenseGemmKernel(CuteDslKernel):
|
||||
|
||||
row_major_stride = (0, 0, 1)
|
||||
col_major_stride = (0, 1, 0)
|
||||
alignment = 16
|
||||
alignment_bytes = 16
|
||||
|
||||
for ab_dtype in ab_dtypes:
|
||||
# Determine valid accumulator types for this A/B dtype
|
||||
valid_acc_dtypes = []
|
||||
|
||||
if (
|
||||
ab_dtype.is_float
|
||||
): # Float32, Float16, BFloat16, Float8E4M3FN, Float8E5M2
|
||||
if ab_dtype.is_float:
|
||||
valid_acc_dtypes.append(cutlass.Float32)
|
||||
if ab_dtype in [
|
||||
cutlass.Float16,
|
||||
@@ -277,15 +276,23 @@ class PersistentDenseGemmKernel(CuteDslKernel):
|
||||
for stride_A, stride_B, stride_out in itertools.product(
|
||||
[row_major_stride, col_major_stride], repeat=3
|
||||
):
|
||||
ab_divisibility = alignment_bytes * 8 // ab_dtype.width
|
||||
out_divisibility = alignment_bytes * 8 // out_dtype.width
|
||||
# Create TensorAttributes for A, B, and out tensors
|
||||
a_attrs = TensorAttributes(
|
||||
dtype=ab_dtype, stride=stride_A, alignment=alignment
|
||||
dtype=ab_dtype,
|
||||
stride=stride_A,
|
||||
divisibility=ab_divisibility,
|
||||
)
|
||||
b_attrs = TensorAttributes(
|
||||
dtype=ab_dtype, stride=stride_B, alignment=alignment
|
||||
dtype=ab_dtype,
|
||||
stride=stride_B,
|
||||
divisibility=ab_divisibility,
|
||||
)
|
||||
out_attrs = TensorAttributes(
|
||||
dtype=out_dtype, stride=stride_out, alignment=alignment
|
||||
dtype=out_dtype,
|
||||
stride=stride_out,
|
||||
divisibility=out_divisibility,
|
||||
)
|
||||
|
||||
# Create and yield the GemmOperandsMetadata
|
||||
@@ -316,26 +323,26 @@ class PersistentDenseGemmKernel(CuteDslKernel):
|
||||
if cluster_size_m * cluster_size_n > 16:
|
||||
return False
|
||||
|
||||
tile = design.tile_shape
|
||||
|
||||
# Constraints based on whether 2CTA instructions are used
|
||||
if design.use_2cta_mma is not None:
|
||||
if design.use_2cta_mma:
|
||||
if cluster_size_m % 2 != 0:
|
||||
return False
|
||||
if design.tile_shape is not None and design.tile_shape[0] not in [
|
||||
if tile is not None and tile[0] not in [
|
||||
128,
|
||||
256,
|
||||
]:
|
||||
return False
|
||||
else:
|
||||
if design.tile_shape is not None and design.tile_shape[0] not in [
|
||||
if tile is not None and tile[0] not in [
|
||||
64,
|
||||
128,
|
||||
]:
|
||||
return False
|
||||
|
||||
if design.tile_shape is not None and design.tile_shape[1] not in range(
|
||||
32, 256, 32
|
||||
):
|
||||
if tile is not None and tile[1] not in range(32, 257, 32):
|
||||
return False
|
||||
|
||||
if metadata.epilogue is not None:
|
||||
@@ -407,8 +414,6 @@ class PersistentDenseGemmKernel(CuteDslKernel):
|
||||
if PersistentDenseGemmKernel._valid_metadata(
|
||||
metadata
|
||||
) and metadata_filter(metadata):
|
||||
kernel_list.append(
|
||||
PersistentDenseGemmKernel(metadata)
|
||||
)
|
||||
kernel_list.append(PersistentDenseGemmKernel(metadata))
|
||||
|
||||
return kernel_list
|
||||
|
||||
@@ -128,7 +128,7 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
|
||||
}
|
||||
|
||||
def __init__(self, metadata: KernelMetadata):
|
||||
self.metadata = metadata
|
||||
super().__init__(metadata)
|
||||
|
||||
if metadata.epilogue is not None:
|
||||
epilogue_op = EFCConverter.convert(
|
||||
@@ -249,7 +249,7 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
|
||||
"""
|
||||
row_major_stride = (0, 0, 1)
|
||||
col_major_stride = (0, 1, 0)
|
||||
alignment = 16
|
||||
alignment_bytes = 16
|
||||
|
||||
for (
|
||||
ab_dtype,
|
||||
@@ -281,15 +281,23 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
|
||||
for stride_A, stride_B, stride_out in itertools.product(
|
||||
[row_major_stride, col_major_stride], repeat=3
|
||||
):
|
||||
ab_divisibility = alignment_bytes * 8 // ab_dtype.width
|
||||
out_divisibility = alignment_bytes * 8 // out_dtype.width
|
||||
# Create TensorAttributes for A, B, and out tensors
|
||||
a_attrs = TensorAttributes(
|
||||
dtype=ab_dtype, stride=stride_A, alignment=alignment
|
||||
dtype=ab_dtype,
|
||||
stride=stride_A,
|
||||
divisibility=ab_divisibility,
|
||||
)
|
||||
b_attrs = TensorAttributes(
|
||||
dtype=ab_dtype, stride=stride_B, alignment=alignment
|
||||
dtype=ab_dtype,
|
||||
stride=stride_B,
|
||||
divisibility=ab_divisibility,
|
||||
)
|
||||
out_attrs = TensorAttributes(
|
||||
dtype=out_dtype, stride=stride_out, alignment=alignment
|
||||
dtype=out_dtype,
|
||||
stride=stride_out,
|
||||
divisibility=out_divisibility,
|
||||
)
|
||||
|
||||
# Create and yield the GemmOperandsMetadata
|
||||
@@ -303,10 +311,13 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
|
||||
yield operands
|
||||
|
||||
def _supports(self, args: GemmArguments) -> Status:
|
||||
|
||||
if args.epilogue is not None:
|
||||
fusion_metadata = EpilogueMetadata.from_args(args.epilogue)
|
||||
if not self._valid_fusion(fusion_metadata):
|
||||
return Status.fail("Provided epilogue fusion is not supported by this kernel")
|
||||
return Status.fail(
|
||||
"Provided epilogue fusion is not supported by this kernel"
|
||||
)
|
||||
|
||||
return Status.success()
|
||||
|
||||
@@ -436,8 +447,6 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
|
||||
if PersistentDenseGemmEFCKernel._valid_metadata(
|
||||
metadata
|
||||
) and metadata_filter(metadata):
|
||||
kernel_list.append(
|
||||
PersistentDenseGemmEFCKernel(metadata)
|
||||
)
|
||||
kernel_list.append(PersistentDenseGemmEFCKernel(metadata))
|
||||
|
||||
return kernel_list
|
||||
|
||||
@@ -0,0 +1,273 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import itertools
|
||||
from collections.abc import Callable, Generator
|
||||
|
||||
import cutlass
|
||||
|
||||
from cutlass_api.arguments import (
|
||||
EpilogueArguments,
|
||||
GemmArguments,
|
||||
)
|
||||
from cutlass_api.artifact import CompiledArtifact
|
||||
from cutlass_api.metadata import (
|
||||
GemmOperandsMetadata,
|
||||
KernelMetadata,
|
||||
TensorAttributes,
|
||||
)
|
||||
from cutlass_api.providers.cutedsl import CuTeDSLProvider
|
||||
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
|
||||
from cutlass_api.status import Status
|
||||
from cutlass_api.utils import strides_to_layout_string, to_cuda_stream, tuple_to_string
|
||||
from cutlass_api.metadata import BLASDesignMetadata
|
||||
from .implementations.sm80_tensorop_gemm_impl import Sm80TensorOpGemmImpl
|
||||
|
||||
|
||||
@CuTeDSLProvider.register
|
||||
class Sm80TensorOpGemmKernel(CuteDslKernel):
|
||||
"""This class implements batched matrix multiplication (C = A @ B)
|
||||
|
||||
:note: In current version, A and B tensor must have the same data type
|
||||
|
||||
:note: Supported A/B data types:
|
||||
- Float16/BFloat16
|
||||
|
||||
:note: Supported accumulator data types:
|
||||
- Float32 (for all floating point A/B data types)
|
||||
:note: Supported C data types:
|
||||
- Float16/BFloat16 (same as A, B)
|
||||
|
||||
:note: Constraints:
|
||||
- MMA tiler M must be 64/128
|
||||
- MMA tiler N must be 64/128
|
||||
"""
|
||||
|
||||
def __init__(self, metadata: KernelMetadata):
|
||||
super().__init__(metadata)
|
||||
self.impl = Sm80TensorOpGemmImpl(
|
||||
metadata.operands.A.dtype,
|
||||
metadata.operands.B.dtype,
|
||||
metadata.operands.out.dtype,
|
||||
metadata.operands.accumulator_type
|
||||
)
|
||||
|
||||
def _supports(self, args: GemmArguments) -> Status:
|
||||
|
||||
if args.epilogue is not None:
|
||||
return Status.fail("This kernel does not support any epilogue fusion.")
|
||||
return Status.success()
|
||||
|
||||
def compile(self, args: GemmArguments, cc: int = None) -> CompiledArtifact:
|
||||
stream = cutlass.cute.runtime.make_fake_stream()
|
||||
|
||||
compiled_kernel = self.cute_compile(
|
||||
self.impl,
|
||||
args.A,
|
||||
args.B,
|
||||
args.out,
|
||||
stream,
|
||||
)
|
||||
return CompiledArtifact(compiled_kernel, self)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
args: GemmArguments,
|
||||
compiled_artifact: CompiledArtifact,
|
||||
stream,
|
||||
workspace=None,
|
||||
) -> None:
|
||||
stream = to_cuda_stream(stream)
|
||||
compiled_gemm = compiled_artifact.compiled_obj
|
||||
self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)
|
||||
|
||||
@staticmethod
|
||||
def _valid_operands(operands: GemmOperandsMetadata) -> bool:
|
||||
if not isinstance(operands, GemmOperandsMetadata):
|
||||
return False
|
||||
|
||||
# In current version, A and B tensor must have the same data type
|
||||
if operands.A.dtype != operands.B.dtype:
|
||||
return False
|
||||
|
||||
# Supported A/B data types:
|
||||
if operands.A.dtype not in [cutlass.Float16, cutlass.BFloat16]:
|
||||
return False
|
||||
|
||||
# Supported accumulator data types:
|
||||
if operands.accumulator_type not in [cutlass.Float32]:
|
||||
return False
|
||||
|
||||
# Supported out data types:
|
||||
if operands.out.dtype not in [cutlass.Float32, cutlass.Float16, cutlass.BFloat16]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _metadata_operand_combinations() -> Generator[GemmOperandsMetadata, None, None]:
|
||||
"""
|
||||
Generator that yields all valid GemmOperandsMetadata combinations
|
||||
based on the validation rules in _valid_operands.
|
||||
"""
|
||||
# Supported A/B data types (must be the same)
|
||||
ab_dtypes = [cutlass.Float16, cutlass.BFloat16]
|
||||
valid_acc_dtypes = [cutlass.Float32]
|
||||
valid_out_dtypes = ab_dtypes
|
||||
|
||||
# Torch conventions (L, M, K) and (L, K, N)
|
||||
row_major_stride = (0, 0, 1)
|
||||
col_major_stride = (0, 1, 0)
|
||||
alignment_bytes = 16
|
||||
|
||||
for ab_dtype in ab_dtypes:
|
||||
for out_dtype in valid_out_dtypes:
|
||||
for acc_dtype in valid_acc_dtypes:
|
||||
for stride_A, stride_B, stride_out in itertools.product(
|
||||
[row_major_stride, col_major_stride], repeat=3
|
||||
):
|
||||
ab_divisibility = alignment_bytes * 8 // ab_dtype.width
|
||||
out_divisibility = alignment_bytes * 8 // out_dtype.width
|
||||
|
||||
# Create TensorAttributes for A, B, and out tensors
|
||||
a_attrs = TensorAttributes(
|
||||
dtype=ab_dtype,
|
||||
stride=stride_A,
|
||||
divisibility=ab_divisibility,
|
||||
)
|
||||
b_attrs = TensorAttributes(
|
||||
dtype=ab_dtype,
|
||||
stride=stride_B,
|
||||
divisibility=ab_divisibility,
|
||||
)
|
||||
out_attrs = TensorAttributes(
|
||||
dtype=out_dtype,
|
||||
stride=stride_out,
|
||||
divisibility=out_divisibility,
|
||||
)
|
||||
|
||||
# Create and yield the GemmOperandsMetadata
|
||||
operands = GemmOperandsMetadata(
|
||||
A=a_attrs,
|
||||
B=b_attrs,
|
||||
out=out_attrs,
|
||||
accumulator_type=acc_dtype,
|
||||
)
|
||||
|
||||
yield operands
|
||||
|
||||
@staticmethod
|
||||
def _valid_metadata(metadata: KernelMetadata) -> bool:
|
||||
if not Sm80TensorOpGemmKernel._valid_operands(metadata.operands):
|
||||
return False
|
||||
|
||||
design = metadata.design
|
||||
if not isinstance(design, BLASDesignMetadata):
|
||||
return False
|
||||
|
||||
if design.tile_shape is None:
|
||||
return False
|
||||
|
||||
# MMA tiler N must be 32/64/128/256
|
||||
if design.tile_shape[1] not in [32, 64, 128]:
|
||||
return False
|
||||
|
||||
# MMA tiler M must be 64/128/256
|
||||
if design.tile_shape[0] not in [64, 128]:
|
||||
return False
|
||||
|
||||
if metadata.epilogue is not None:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def generate_kernels(
|
||||
metadata_filter: Callable[[KernelMetadata], bool],
|
||||
epilogue_args: EpilogueArguments = None,
|
||||
cc: int = None,
|
||||
) -> list["Sm80TensorOpGemmKernel"]:
|
||||
"""
|
||||
Returns a list of all possible configurations of Sm80TensorOpGemmKernel that
|
||||
adhere to constraints passed in under kwargs.
|
||||
"""
|
||||
min_cc = 80
|
||||
if cc is not None and cc < min_cc:
|
||||
return []
|
||||
|
||||
design_params = {
|
||||
"tile_shape": [
|
||||
(M, N, 32) for M in [64, 128] for N in [64, 128]
|
||||
],
|
||||
# SM80 kernels do not currently use cluster_shape for tuning; fix it to a
|
||||
# single valid value to satisfy the BLASDesignMetadata interface.
|
||||
"cluster_shape": [(1, 1, 1)],
|
||||
}
|
||||
|
||||
if epilogue_args is not None:
|
||||
return []
|
||||
|
||||
from itertools import product
|
||||
|
||||
# Get the list of tunable parameter names and their possible values
|
||||
param_names = list(design_params.keys())
|
||||
param_values = [design_params[name] for name in param_names]
|
||||
|
||||
kernel_list = []
|
||||
|
||||
for operands in Sm80TensorOpGemmKernel._metadata_operand_combinations():
|
||||
for values in product(*param_values):
|
||||
design = BLASDesignMetadata(**dict(zip(param_names, values)))
|
||||
kernel_name = "cutedsl.Sm80TensorOpGemmKernel_{layout}_A{A}_B{B}_out{out}_acc{acc}_tile{tile}".format(
|
||||
layout=strides_to_layout_string(
|
||||
operands.A.stride, operands.B.stride, operands.out.stride
|
||||
),
|
||||
A=operands.A.dtype,
|
||||
B=operands.B.dtype,
|
||||
out=operands.out.dtype,
|
||||
acc=operands.accumulator_type,
|
||||
tile=tuple_to_string(design.tile_shape),
|
||||
)
|
||||
|
||||
metadata = KernelMetadata(
|
||||
operands=operands,
|
||||
design=design,
|
||||
kernel_name=kernel_name,
|
||||
kernel_class=Sm80TensorOpGemmKernel,
|
||||
min_cc=min_cc,
|
||||
epilogue=None,
|
||||
)
|
||||
|
||||
if Sm80TensorOpGemmKernel._valid_metadata(
|
||||
metadata
|
||||
) and metadata_filter(metadata):
|
||||
kernel_list.append(
|
||||
Sm80TensorOpGemmKernel(metadata)
|
||||
)
|
||||
|
||||
return kernel_list
|
||||
@@ -26,12 +26,12 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import cutlass.cute as cute
|
||||
|
||||
from cutlass_api.config import GlobalOptions
|
||||
from cutlass_api.kernel import Kernel
|
||||
from cutlass_api.utils import TensorWrapper
|
||||
|
||||
import cutlass.cute as cute
|
||||
|
||||
|
||||
class CuteDslKernel(Kernel):
|
||||
"""
|
||||
|
||||
@@ -58,4 +58,3 @@ class Status:
|
||||
"""Raise the stored exception if this status represents a failure."""
|
||||
if self.error is not None:
|
||||
raise self.error
|
||||
|
||||
|
||||
@@ -308,21 +308,54 @@ def leading_dim(tensor) -> int:
|
||||
|
||||
def get_stride_order(stride: tuple[int, ...]) -> tuple[int, ...]:
|
||||
"""
|
||||
Returns the order of the stride of a tensor. For a stride of rank N,
|
||||
the dimension with the smallest stride will have stride order 0 and the
|
||||
dimension with the largest stride will have stride order N-1.
|
||||
Returns the order of the modes of a stride. The returned tuple contains
|
||||
indices of modes in `stride`, sorted from outermost to innermost.
|
||||
|
||||
For example, for a stride (1024, 1, 512), the stride order is (0, 2, 1).
|
||||
|
||||
:param stride: The stride of the tensor.
|
||||
:type stride: tuple[int, ...]
|
||||
|
||||
:return: The order of the stride of the tensor.
|
||||
:return: The order of the modes of the stride.
|
||||
:rtype: tuple[int, ...]
|
||||
"""
|
||||
# The code below performs an == reverse argsort on the stride:
|
||||
# indices = range(len(stride))
|
||||
# Sort indices using comparison between stride[i] and stride[j] when
|
||||
# sorting indices i and j. Sort in descending order.
|
||||
# Example: For a stride (1024, 1, 512), the reverse argsort is (0, 2, 1).
|
||||
return tuple(sorted(range(len(stride)), key=stride.__getitem__, reverse=True))
|
||||
|
||||
|
||||
def get_stride_rank(stride: tuple[int, ...]) -> tuple[int, ...]:
|
||||
"""
|
||||
Returns the rank of the each mode in the stride of a tensor. For a stride of rank N,
|
||||
the mode with the smallest stride will have stride rank 0 and the
|
||||
mode with the largest stride will have stride rank N-1.
|
||||
|
||||
For example, for a stride (1024, 1, 512), the stride rank is (2, 0, 1).
|
||||
|
||||
:param stride: The stride of the tensor.
|
||||
:type stride: tuple[int, ...]
|
||||
|
||||
:return: The rank of the each mode in the stride of the tensor.
|
||||
:rtype: tuple[int, ...]
|
||||
"""
|
||||
# The code below performs an argsort on the stride:
|
||||
# indices = range(len(stride))
|
||||
# Sort indices using comparison between stride[i] and stride[j] when
|
||||
# sorting indices i and j.
|
||||
return tuple(sorted(range(len(stride)), key=stride.__getitem__))
|
||||
# Example: For a stride (1024, 1, 512), the argsorted is (1, 2, 0).
|
||||
argsorted = tuple(sorted(range(len(stride)), key=stride.__getitem__))
|
||||
|
||||
# Set the stride order of each mode to the index of the mode in the
|
||||
# argsorted tuple.
|
||||
# Example: For a stride (1024, 1, 512), the argsorted is (1, 2, 0),
|
||||
# so the stride rank is (2, 0, 1).
|
||||
res = [-1] * len(stride)
|
||||
for i, idx in enumerate(argsorted):
|
||||
res[idx] = i
|
||||
return tuple(res)
|
||||
|
||||
|
||||
class TensorWrapper:
|
||||
@@ -353,6 +386,7 @@ class TensorWrapper:
|
||||
self.compile_time_tensor = tensor
|
||||
self._shape = tensor.shape
|
||||
self._stride = tensor.stride
|
||||
self._data_ptr = tensor.iterator._pointer
|
||||
elif GlobalOptions().use_tvm_ffi:
|
||||
# If TVM-FFI is enabled, runtime tensor is set simply as the tensor passed in, but
|
||||
# we must make a fake tensor for compilation.
|
||||
@@ -362,28 +396,48 @@ class TensorWrapper:
|
||||
|
||||
rank = self.runtime_tensor.dim()
|
||||
self._stride = self.runtime_tensor.stride()
|
||||
stride_order = get_stride_order(self._stride)
|
||||
|
||||
stride_order = get_stride_rank(self._stride)
|
||||
leading_dim_idx = stride_order.index(0)
|
||||
shape = [cute.SymInt() for _ in range(rank)]
|
||||
shape[stride_order.index(0)] = cute.SymInt(divisibility=16)
|
||||
shape[leading_dim_idx] = cute.SymInt(divisibility=16 * 8 // dtype.width)
|
||||
self._shape = tuple(self.runtime_tensor.shape)
|
||||
self._data_ptr = self.runtime_tensor.data_ptr()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported tensor type: {type(self.runtime_tensor)}"
|
||||
)
|
||||
|
||||
self.compile_time_tensor = cute.runtime.make_fake_compact_tensor(
|
||||
dtype, shape, stride_order=stride_order, assumed_align=16
|
||||
dtype,
|
||||
shape,
|
||||
stride_order=stride_order,
|
||||
assumed_align=16, # bytes
|
||||
)
|
||||
else:
|
||||
# TVM-FFI is disabled and the tensor passed in is not a cute.Tensor,
|
||||
# We must convert it to a cute.Tensor
|
||||
self.runtime_tensor = from_dlpack(
|
||||
tensor, assumed_align=16
|
||||
).mark_layout_dynamic(leading_dim(tensor))
|
||||
if is_torch_tensor(tensor):
|
||||
dtype = to_cutlass_type(tensor.dtype)
|
||||
stride = tensor.stride()
|
||||
else:
|
||||
raise ValueError(f"Unsupported tensor type: {type(tensor)}")
|
||||
stride_order = get_stride_order(stride)
|
||||
self.runtime_tensor = (
|
||||
from_dlpack(
|
||||
tensor,
|
||||
assumed_align=16, # bytes
|
||||
)
|
||||
.mark_layout_dynamic(leading_dim(tensor))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=leading_dim(tensor),
|
||||
divisibility=16 * 8 // dtype.width,
|
||||
stride_order=stride_order,
|
||||
)
|
||||
)
|
||||
|
||||
self._shape = self.runtime_tensor.shape
|
||||
self._stride = self.runtime_tensor.stride
|
||||
self._data_ptr = self.runtime_tensor.iterator._pointer
|
||||
|
||||
# Since the runtime tensor is now a cute.Tensor, we can use it at
|
||||
# compile time as well
|
||||
@@ -401,6 +455,10 @@ class TensorWrapper:
|
||||
def stride(self) -> tuple[int, ...]:
|
||||
return self._stride
|
||||
|
||||
@property
|
||||
def data_ptr(self) -> int:
|
||||
return self._data_ptr
|
||||
|
||||
|
||||
def strides_to_layout_string(*strides: list[tuple[int, ...]]) -> str:
|
||||
"""
|
||||
|
||||
@@ -31,9 +31,9 @@
|
||||
"\n",
|
||||
"import cutlass_api\n",
|
||||
"\n",
|
||||
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
|
||||
"if not (status := cutlass_api.utils.is_device_cc_supported({80, 90, 100, 103})):\n",
|
||||
" print(\n",
|
||||
" f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\"\n",
|
||||
" f\"This notebook requires a GPU with compute capability >= 80.\\n{status.error}\"\n",
|
||||
" )\n",
|
||||
" import sys\n",
|
||||
"\n",
|
||||
@@ -67,7 +67,7 @@
|
||||
"source": [
|
||||
"M, N, K, L = 128, 256, 64, 2\n",
|
||||
"ab_type = torch.float16\n",
|
||||
"out_type = torch.float32\n",
|
||||
"out_type = torch.float16\n",
|
||||
"acc_type = torch.float32\n",
|
||||
"\n",
|
||||
"A = torch.randint(-1, 2, (L, M, K), device=\"cuda\", dtype=ab_type)\n",
|
||||
@@ -118,9 +118,9 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"kernels = cutlass_api.get_kernels(args)\n",
|
||||
"cc = cutlass_api.utils.device_cc()\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
|
||||
"assert kernels, \"No kernels found for the given arguments!\"\n",
|
||||
"\n",
|
||||
"kernel = kernels[0]"
|
||||
]
|
||||
},
|
||||
@@ -148,28 +148,6 @@
|
||||
"torch.testing.assert_close(out, reference)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4d7ad85b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"One can also explicitly compile the kernel and pass this in to `kernel.run` to avoid\n",
|
||||
"JIT compilation on future invocations. Additional details related to this will be\n",
|
||||
"described below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "06f9f844",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"artifact = kernel.compile(args)\n",
|
||||
"kernel.run(args, compiled_artifact=artifact)\n",
|
||||
"torch.testing.assert_close(out, reference)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "630e9e4b",
|
||||
@@ -179,7 +157,7 @@
|
||||
"\n",
|
||||
"---\n",
|
||||
"\n",
|
||||
"### Understanding the core interfaces"
|
||||
"## Understanding the core interfaces"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -187,7 +165,7 @@
|
||||
"id": "2d8b8e94",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### 1. `RuntimeArguments` / `GemmArguments`\n",
|
||||
"### 1. `RuntimeArguments` / `GemmArguments`\n",
|
||||
"\n",
|
||||
"`RuntimeArguments` describe the operation a user wants to perform, and all the runtime operands or other runtime parameters needed for it. \n",
|
||||
"This includes primary runtime operands to the operation, as well as any custom epilogue fusions and runtime performance knobs.\n",
|
||||
@@ -220,7 +198,7 @@
|
||||
"id": "e7eda0dd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### 2. Kernel Discovery\n",
|
||||
"### 2. Kernel Discovery\n",
|
||||
"\n",
|
||||
"There are several kernels available in CUTLASS DSLs that are registered with, and discoverable via, the CUTLASS API.\n",
|
||||
"\n",
|
||||
@@ -254,6 +232,11 @@
|
||||
"kernels = cutlass_api.get_kernels(args)\n",
|
||||
"print(f\"Of these, {len(kernels)} support the given arguments.\")\n",
|
||||
"\n",
|
||||
"# we can limit the search to kernels supporting given args + current device compute capability\n",
|
||||
"cc = cutlass_api.utils.device_cc()\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
|
||||
"print(f\"Of these, {len(kernels)} support the given arguments.\")\n",
|
||||
"\n",
|
||||
"kernel = kernels[0]\n",
|
||||
"print(f\"Picked kernel with name: {kernel.metadata.kernel_name}\")"
|
||||
]
|
||||
@@ -263,7 +246,7 @@
|
||||
"id": "252a4d38",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### 3. `Kernel` execution"
|
||||
"### 3. `Kernel` execution"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -286,7 +269,8 @@
|
||||
"id": "e8945aa6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"* `kernel.supports(args)` checks if the kernel supports the given `args`\n",
|
||||
"#### Verify that the kernel supports the given `args`\n",
|
||||
"`kernel.supports(args)` checks if the kernel supports the given `args`\n",
|
||||
" * this is relevant if the kernel was not picked just for these `args`"
|
||||
]
|
||||
},
|
||||
@@ -338,7 +322,9 @@
|
||||
"id": "c2db8f20",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"* `kernel.compile(args)` compiles the kernel, and returns a `CompiledArtifact`\n",
|
||||
"#### JIT compiling the kernel\n",
|
||||
"\n",
|
||||
"`kernel.compile(args)` compiles the kernel, and returns a `CompiledArtifact`\n",
|
||||
"\n",
|
||||
"This compiled artifact is a lightweight wrapper over the result of compiling a kernel (e.g., via `cute.compile()`).\n",
|
||||
"\n",
|
||||
@@ -361,7 +347,8 @@
|
||||
"id": "4dfb8d51",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"* `kernel.run(args)` launches the compiled kernel function. This example uses:\n",
|
||||
"#### Launching the compiled kernel function\n",
|
||||
"`kernel.run(args)` launches the compiled kernel function. The next example uses:\n",
|
||||
" * the precompiled artifact\n",
|
||||
" * a custom stream to launch to\n",
|
||||
" * bypasses the supports check already performed above (`assume_supported_args=True`)."
|
||||
@@ -386,6 +373,24 @@
|
||||
"torch.testing.assert_close(out, reference)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7956813d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Passing in a precompiled kernel is critical to achieving good performance because it avoids\n",
|
||||
"JIT compiling the kernel on each invocation. JIT compilation always occurs when a precompiled\n",
|
||||
"kernel is not provided in the call to `kernel.run()`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c228495a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Workspace Buffers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f67eeb8f",
|
||||
@@ -416,7 +421,7 @@
|
||||
"id": "baffaf12",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Advanced: Filtering on Metadata"
|
||||
"## Advanced: Filtering on Metadata"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"id": "bb450878",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -20,7 +20,9 @@
|
||||
"import cutlass_api\n",
|
||||
"\n",
|
||||
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
|
||||
" print(f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\")\n",
|
||||
" print(\n",
|
||||
" f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\"\n",
|
||||
" )\n",
|
||||
" import sys\n",
|
||||
"\n",
|
||||
" sys.exit(0)"
|
||||
@@ -42,7 +44,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"id": "e6d77d53",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -56,13 +58,15 @@
|
||||
"B = torch.randn(L, K, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"C = torch.randn(L, M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def my_epilogue(accum, C, alpha, beta, extra_scalar):\n",
|
||||
" Aux = (alpha * accum) + (beta * C)\n",
|
||||
" D = extra_scalar * Aux\n",
|
||||
" return D, Aux\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"alpha, beta, extra_scalar = 1.0, 2.0, 0.5\n",
|
||||
"D, Aux = my_epilogue(A @ B, C, alpha, beta, extra_scalar)\n"
|
||||
"D, Aux = my_epilogue(A @ B, C, alpha, beta, extra_scalar)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -77,18 +81,28 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 3,
|
||||
"id": "f079d9d6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cutlass_api\n",
|
||||
"from cutlass_api.arguments import GemmArguments, EpilogueArguments\n",
|
||||
"from cutlass_api.arguments import EpilogueArguments, GemmArguments\n",
|
||||
"\n",
|
||||
"# Allocate buffers for D and Aux\n",
|
||||
"D_, Aux_ = [torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16) for _ in range(2)]\n",
|
||||
"D_, Aux_ = [\n",
|
||||
" torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16) for _ in range(2)\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"epi_args = EpilogueArguments(my_epilogue, C=C, alpha=alpha, beta=beta, extra_scalar=extra_scalar, D=D_, Aux=Aux_)\n"
|
||||
"epi_args = EpilogueArguments(\n",
|
||||
" my_epilogue,\n",
|
||||
" C=C,\n",
|
||||
" alpha=alpha,\n",
|
||||
" beta=beta,\n",
|
||||
" extra_scalar=extra_scalar,\n",
|
||||
" D=D_,\n",
|
||||
" Aux=Aux_,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -101,14 +115,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"id": "60215c4e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
"kernels = cutlass_api.get_kernels(args)\n",
|
||||
"assert len(kernels) > 0\n"
|
||||
"cc = cutlass_api.utils.device_cc()\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
|
||||
"assert len(kernels) > 0"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -122,7 +137,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 5,
|
||||
"id": "150f3296",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -130,7 +145,7 @@
|
||||
"kernels[0].run(args)\n",
|
||||
"\n",
|
||||
"torch.testing.assert_close(D, D_)\n",
|
||||
"torch.testing.assert_close(Aux, Aux_)\n"
|
||||
"torch.testing.assert_close(Aux, Aux_)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -289,17 +304,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 6,
|
||||
"id": "171ac178",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from cutlass_api.fusion.activation import relu\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def relu_aux_store(accum, alpha, C):\n",
|
||||
" F = (accum * alpha) + (C * 2.0) # Constant beta\n",
|
||||
" D = relu(F)\n",
|
||||
" return D, F\n",
|
||||
" F = (accum * alpha) + (C * 2.0) # Constant beta\n",
|
||||
" D = relu(F)\n",
|
||||
" return D, F\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"C = torch.randn((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
|
||||
"alpha = 3.0\n",
|
||||
@@ -308,14 +325,14 @@
|
||||
"\n",
|
||||
"epi_args = EpilogueArguments(relu_aux_store, alpha=alpha, C=C, D=D, F=F)\n",
|
||||
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
|
||||
"assert len(kernels) > 0\n",
|
||||
"kernels[0].run(args)\n",
|
||||
"\n",
|
||||
"D_ref, F_ref = relu_aux_store(A @ B, alpha, C)\n",
|
||||
"\n",
|
||||
"torch.testing.assert_close(D, D_ref)\n",
|
||||
"torch.testing.assert_close(F, F_ref)\n"
|
||||
"torch.testing.assert_close(F, F_ref)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -328,15 +345,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 7,
|
||||
"id": "62c2b49b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def relu_scale_return_acc(accum, alpha, beta, C, scale):\n",
|
||||
" F = relu((accum * alpha) + (C * beta))\n",
|
||||
" D = F * scale\n",
|
||||
" return D, F, accum\n",
|
||||
" F = relu((accum * alpha) + (C * beta))\n",
|
||||
" D = F * scale\n",
|
||||
" return D, F, accum\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"C = torch.randn((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
|
||||
"alpha = 1.0\n",
|
||||
@@ -346,9 +364,18 @@
|
||||
"F = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float16)\n",
|
||||
"accum = torch.empty((L, M, N), device=\"cuda\", dtype=torch.float32)\n",
|
||||
"\n",
|
||||
"epi_args = EpilogueArguments(relu_scale_return_acc, alpha=alpha, beta=beta, C=C, scale=scale, D=D, F=F, accum=accum)\n",
|
||||
"epi_args = EpilogueArguments(\n",
|
||||
" relu_scale_return_acc,\n",
|
||||
" alpha=alpha,\n",
|
||||
" beta=beta,\n",
|
||||
" C=C,\n",
|
||||
" scale=scale,\n",
|
||||
" D=D,\n",
|
||||
" F=F,\n",
|
||||
" accum=accum,\n",
|
||||
")\n",
|
||||
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
|
||||
"assert len(kernels) > 0\n",
|
||||
"kernels[0].run(args)\n",
|
||||
"\n",
|
||||
@@ -356,7 +383,7 @@
|
||||
"\n",
|
||||
"torch.testing.assert_close(D, D_ref)\n",
|
||||
"torch.testing.assert_close(F, F_ref)\n",
|
||||
"torch.testing.assert_close(accum, accum_ref.to(accum.dtype))\n"
|
||||
"torch.testing.assert_close(accum, accum_ref.to(accum.dtype))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -370,7 +397,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 8,
|
||||
"id": "5987bf44",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -385,7 +412,7 @@
|
||||
"\n",
|
||||
"epi_args = EpilogueArguments(epi_str, alpha=alpha, beta=beta, C=C, D=D, F=F)\n",
|
||||
"args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
|
||||
"kernels = cutlass_api.get_kernels(args, cc=cc)\n",
|
||||
"assert len(kernels) > 0\n",
|
||||
"kernels[0].run(args)\n",
|
||||
"\n",
|
||||
@@ -393,7 +420,7 @@
|
||||
"D_ref = torch.relu(F_ref)\n",
|
||||
"\n",
|
||||
"torch.testing.assert_close(D, D_ref)\n",
|
||||
"torch.testing.assert_close(F, F_ref)\n"
|
||||
"torch.testing.assert_close(F, F_ref)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -407,54 +434,90 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"id": "1e3d0c89",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"accum must be an input to the epilogue function\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"####################################################\n",
|
||||
"# Epilogues must take in an accumulator\n",
|
||||
"####################################################\n",
|
||||
"def fail_missing_accum(alpha, beta, C):\n",
|
||||
" D = (C * beta)\n",
|
||||
" return D\n",
|
||||
" D = C * beta\n",
|
||||
" return D\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" epi_args = EpilogueArguments(fail_missing_accum, alpha=alpha, beta=beta, C=C, D=D)\n",
|
||||
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
" epi_args = EpilogueArguments(\n",
|
||||
" fail_missing_accum,\n",
|
||||
" alpha=alpha,\n",
|
||||
" beta=beta,\n",
|
||||
" C=C,\n",
|
||||
" D=D,\n",
|
||||
" )\n",
|
||||
" args = GemmArguments(\n",
|
||||
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
|
||||
" )\n",
|
||||
"except Exception as e:\n",
|
||||
" # \"accum must be an input to the epilogue function\"\n",
|
||||
" print(e)\n"
|
||||
" # \"accum must be an input to the epilogue function\"\n",
|
||||
" print(e)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 10,
|
||||
"id": "48a359f7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Output node D is not found in the epilogue function\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"####################################################\n",
|
||||
"# Epilogues must return an output named D\n",
|
||||
"####################################################\n",
|
||||
"def fail_missing_D(accum, alpha, beta, C):\n",
|
||||
" F = (accum * alpha) + (C * beta)\n",
|
||||
" return F\n",
|
||||
" F = (accum * alpha) + (C * beta)\n",
|
||||
" return F\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" epi_args = EpilogueArguments(fail_missing_D, alpha=alpha, beta=beta, C=C, F=F)\n",
|
||||
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
" epi_args = EpilogueArguments(fail_missing_D, alpha=alpha, beta=beta, C=C, F=F)\n",
|
||||
" args = GemmArguments(\n",
|
||||
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
|
||||
" )\n",
|
||||
"except Exception as e:\n",
|
||||
" # \"On SM90 or higher, D is expected to be a output node with 0 users to enable smem reuse between C and D, but got []\"\n",
|
||||
" print(e)\n"
|
||||
" # \"Output node D is not found in the epilogue function\n",
|
||||
" print(e)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 11,
|
||||
"id": "49d9ee94",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Variable 'tmp' cannot be defined twice.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"####################################################\n",
|
||||
"# Epilogues must use single-static assignment (SSA)\n",
|
||||
@@ -466,51 +529,81 @@
|
||||
" D = tmp / 4.0\n",
|
||||
" return D, tmp\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" epi_args = EpilogueArguments(fail_ssa, D=D, tmp=F)\n",
|
||||
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
" epi_args = EpilogueArguments(fail_ssa, D=D, tmp=F)\n",
|
||||
" args = GemmArguments(\n",
|
||||
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
|
||||
" )\n",
|
||||
"except Exception as e:\n",
|
||||
" # \"Variable 'tmp' cannot be defined twice.\"\n",
|
||||
" print(e)\n"
|
||||
" # \"Variable 'tmp' cannot be defined twice.\"\n",
|
||||
" print(e)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 12,
|
||||
"id": "871bb727",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Argument D is not provided in the kwargs of the EpilogueArguments constructor\n",
|
||||
"Argument alpha is not provided in the kwargs of the EpilogueArguments constructor\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"####################################################\n",
|
||||
"# Must provide all operands and outputs to\n",
|
||||
"# EpilogueArguments\n",
|
||||
"####################################################\n",
|
||||
"def my_epi(accum, alpha, beta, C):\n",
|
||||
" F = (accum * alpha) + (C * beta)\n",
|
||||
" D = relu(F)\n",
|
||||
" return D\n",
|
||||
" F = (accum * alpha) + (C * beta)\n",
|
||||
" D = relu(F)\n",
|
||||
" return D\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" # Missing D\n",
|
||||
" epi_args = EpilogueArguments(my_epi, alpha=alpha, beta=beta, C=C)\n",
|
||||
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
" # Missing D\n",
|
||||
" epi_args = EpilogueArguments(my_epi, alpha=alpha, beta=beta, C=C)\n",
|
||||
" args = GemmArguments(\n",
|
||||
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
|
||||
" )\n",
|
||||
"except Exception as e:\n",
|
||||
" # \"Argument D is not provided in the kwargs of the EpilogueArguments constructor\"\n",
|
||||
" print(e)\n",
|
||||
" # \"Argument D is not provided in the kwargs of the EpilogueArguments constructor\"\n",
|
||||
" print(e)\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" # Missing alpha\n",
|
||||
" epi_args = EpilogueArguments(my_epi, beta=beta, C=C, D=D)\n",
|
||||
" args = GemmArguments(A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args)\n",
|
||||
" # Missing alpha\n",
|
||||
" epi_args = EpilogueArguments(my_epi, beta=beta, C=C, D=D)\n",
|
||||
" args = GemmArguments(\n",
|
||||
" A=A, B=B, out=D, accumulator_type=torch.float32, epilogue=epi_args\n",
|
||||
" )\n",
|
||||
"except Exception as e:\n",
|
||||
" # \"Argument alpha is not provided in the kwargs of the EpilogueArguments constructor\"\n",
|
||||
" print(e)\n"
|
||||
" # \"Argument alpha is not provided in the kwargs of the EpilogueArguments constructor\"\n",
|
||||
" print(e)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -24,12 +24,12 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"id": "5a64b0be",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Callable\n",
|
||||
"from collections.abc import Callable\n",
|
||||
"\n",
|
||||
"import cuda.bindings.driver as cuda\n",
|
||||
"\n",
|
||||
@@ -110,7 +110,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"id": "1a2da869",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -148,18 +148,34 @@
|
||||
"class F64GemmKernel(cutlass_api.providers.cutedsl.kernel.CuteDslKernel):\n",
|
||||
" # Empty versions of interface methods. These will be implemented later, interspersed\n",
|
||||
" # with notebook markdown. Normally, one would define them inline with the class definition.\n",
|
||||
" def __init__(self, metadata: KernelMetadata): pass\n",
|
||||
" def __init__(self, metadata: KernelMetadata):\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" def _run(self, args: GemmArguments, artifact: cutlass_api.artifact.CompiledArtifact, stream, workspace=None): pass\n",
|
||||
" def _run(\n",
|
||||
" self,\n",
|
||||
" args: GemmArguments,\n",
|
||||
" artifact: cutlass_api.artifact.CompiledArtifact,\n",
|
||||
" stream,\n",
|
||||
" workspace=None,\n",
|
||||
" ):\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" def compile(self, args: GemmArguments, cc: int = None) -> cutlass_api.artifact.CompiledArtifact: pass\n",
|
||||
" def compile(\n",
|
||||
" self, args: GemmArguments, cc: int = None\n",
|
||||
" ) -> cutlass_api.artifact.CompiledArtifact:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" @staticmethod\n",
|
||||
" def generate_kernels(metadata_filter, epilogue_args=None, cc=None) -> list[\"F64GemmKernel\"]: pass\n",
|
||||
" def generate_kernels(\n",
|
||||
" metadata_filter, epilogue_args=None, cc=None\n",
|
||||
" ) -> list[\"F64GemmKernel\"]:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" def _supports(self, args: GemmArguments) -> Status: pass\n",
|
||||
" def _supports(self, args: GemmArguments) -> Status:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" def get_workspace_size(self, args: GemmArguments) -> int: pass"
|
||||
" def get_workspace_size(self, args: GemmArguments) -> int:\n",
|
||||
" pass"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -175,13 +191,14 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"id": "785d1882",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def __init__(self, metadata: KernelMetadata):\n",
|
||||
" self.metadata = metadata\n",
|
||||
" # Using Python-2-style super() because we're defining this method outside of the class definition.\n",
|
||||
" super(F64GemmKernel, self).__init__(metadata)\n",
|
||||
" cta_tile_shape_mn = metadata.design.tile_shape[:2]\n",
|
||||
" self.impl = F64GemmKernelImplementation(cta_tile_shape_mn)"
|
||||
]
|
||||
@@ -203,12 +220,14 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 5,
|
||||
"id": "63b4a129",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def compile(self, args: GemmArguments, cc: int = None) -> cutlass_api.artifact.CompiledArtifact:\n",
|
||||
"def compile(\n",
|
||||
" self, args: GemmArguments, cc: int = None\n",
|
||||
") -> cutlass_api.artifact.CompiledArtifact:\n",
|
||||
" stream = cutlass.cute.runtime.make_fake_stream()\n",
|
||||
" compiled_gemm = self.cute_compile(self.impl, args.A, args.B, args.out, stream)\n",
|
||||
" return cutlass_api.artifact.CompiledArtifact(compiled_gemm, self)"
|
||||
@@ -227,12 +246,18 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 6,
|
||||
"id": "2ae7c009",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _run(self, args: GemmArguments, artifact: cutlass_api.artifact.CompiledArtifact, stream, workspace=None):\n",
|
||||
"def _run(\n",
|
||||
" self,\n",
|
||||
" args: GemmArguments,\n",
|
||||
" artifact: cutlass_api.artifact.CompiledArtifact,\n",
|
||||
" stream,\n",
|
||||
" workspace=None,\n",
|
||||
"):\n",
|
||||
" stream = cutlass_api.utils.to_cuda_stream(stream)\n",
|
||||
" compiled_gemm = artifact.compiled_obj\n",
|
||||
" self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)"
|
||||
@@ -249,7 +274,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 7,
|
||||
"id": "968906ea",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -286,7 +311,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 8,
|
||||
"id": "47dc2f20",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -297,7 +322,6 @@
|
||||
" epilogue_args: cutlass_api.arguments.EpilogueArguments = None,\n",
|
||||
" cc: int = None,\n",
|
||||
") -> list[\"F64GemmKernel\"]:\n",
|
||||
"\n",
|
||||
" # The tile shapes this kernel supports/exposes\n",
|
||||
" supported_tile_shapes = [(32, 32, 1), (16, 16, 1)]\n",
|
||||
"\n",
|
||||
@@ -306,10 +330,12 @@
|
||||
"\n",
|
||||
" row_major_stride = (0, 0, 1)\n",
|
||||
" col_major_stride = (0, 1, 0)\n",
|
||||
" stride_combos = list(itertools.product([row_major_stride, col_major_stride], repeat=3))\n",
|
||||
" alignment = 1\n",
|
||||
" stride_combos = list(\n",
|
||||
" itertools.product([row_major_stride, col_major_stride], repeat=3)\n",
|
||||
" )\n",
|
||||
" divisibility = 1\n",
|
||||
"\n",
|
||||
" def stride_name(stride): \n",
|
||||
" def stride_name(stride):\n",
|
||||
" return \"T\" if stride == row_major_stride else \"N\"\n",
|
||||
"\n",
|
||||
" kernels = []\n",
|
||||
@@ -317,10 +343,18 @@
|
||||
" design_metadata = cutlass_api.metadata.BLASDesignMetadata(tile_shape, (1, 1, 1))\n",
|
||||
" for stride_A, stride_B, stride_out in stride_combos:\n",
|
||||
" # Create TensorAttributes for A, B, and out tensors\n",
|
||||
" a_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_A, alignment)\n",
|
||||
" b_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_B, alignment)\n",
|
||||
" out_attrs = cutlass_api.metadata.TensorAttributes(cutlass.Float64, stride_out, alignment)\n",
|
||||
" layout_str = cutlass_api.utils.strides_to_layout_string(stride_A, stride_B, stride_out)\n",
|
||||
" a_attrs = cutlass_api.metadata.TensorAttributes(\n",
|
||||
" cutlass.Float64, stride_A, divisibility\n",
|
||||
" )\n",
|
||||
" b_attrs = cutlass_api.metadata.TensorAttributes(\n",
|
||||
" cutlass.Float64, stride_B, divisibility\n",
|
||||
" )\n",
|
||||
" out_attrs = cutlass_api.metadata.TensorAttributes(\n",
|
||||
" cutlass.Float64, stride_out, divisibility\n",
|
||||
" )\n",
|
||||
" layout_str = cutlass_api.utils.strides_to_layout_string(\n",
|
||||
" stride_A, stride_B, stride_out\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" name = f\"F64GemmKernel_tile{tile_shape[0]}x{tile_shape[1]}_{layout_str}\"\n",
|
||||
"\n",
|
||||
@@ -355,24 +389,24 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"id": "54067d47",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _supports(self, args: GemmArguments) -> Status:\n",
|
||||
" if not (\n",
|
||||
" len(args.A.shape) == 3 and # A should be (L, M, K)\n",
|
||||
" len(args.B.shape) == 3 and # B should be (L, K, N)\n",
|
||||
" len(args.out.shape) == 3 # out should be (L, M, N)\n",
|
||||
" len(args.A.shape) == 3 # A should be (L, M, K)\n",
|
||||
" and len(args.B.shape) == 3 # B should be (L, K, N)\n",
|
||||
" and len(args.out.shape) == 3 # out should be (L, M, N)\n",
|
||||
" ):\n",
|
||||
" return Status.fail(\"All operands must be rank 3.\")\n",
|
||||
" return Status.success()\n"
|
||||
" return Status.success()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 10,
|
||||
"id": "edaf2cba",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -404,7 +438,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 11,
|
||||
"id": "cec5431d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -420,9 +454,11 @@
|
||||
"\n",
|
||||
"args = GemmArguments(A, B, out, accumulator_type=torch.float64)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def is_f64gemm_kernel(metadata):\n",
|
||||
" return metadata.kernel_class == F64GemmKernel\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"kernels = cutlass_api.get_kernels(args, metadata_filter=is_f64gemm_kernel)"
|
||||
]
|
||||
},
|
||||
@@ -437,10 +473,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 12,
|
||||
"id": "cdb92b5e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"F64GemmKernel_tile32x32_ttt\n",
|
||||
"F64GemmKernel_tile16x16_ttt\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(kernels[0].metadata.kernel_name)\n",
|
||||
"print(kernels[1].metadata.kernel_name)"
|
||||
@@ -456,7 +501,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 13,
|
||||
"id": "f5486244",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -478,17 +523,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 14,
|
||||
"id": "917c74e3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def my_filter(metadata):\n",
|
||||
" return (\n",
|
||||
" is_f64gemm_kernel(metadata) and\n",
|
||||
" isinstance(metadata.design, cutlass_api.metadata.BLASDesignMetadata) and\n",
|
||||
" metadata.design.tile_shape[0] == 256\n",
|
||||
" is_f64gemm_kernel(metadata)\n",
|
||||
" and isinstance(metadata.design, cutlass_api.metadata.BLASDesignMetadata)\n",
|
||||
" and metadata.design.tile_shape[0] == 256\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"kernels_ctam256 = cutlass_api.get_kernels(args, metadata_filter=my_filter)\n",
|
||||
"\n",
|
||||
"# No kernels should be found\n",
|
||||
@@ -539,8 +586,22 @@
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -17,12 +17,15 @@
|
||||
"This notebook focuses on the latter: techniques to minimize any overheads incurred from the CUTLASS API and underlying\n",
|
||||
"DSL runtimes.\n",
|
||||
"\n",
|
||||
"This notebook does not discuss techniques for improving device-side performance. A future notebook may cover this topic."
|
||||
"This notebook does not discuss techniques for improving device-side performance. A future notebook may cover this topic.\n",
|
||||
"\n",
|
||||
"**Note**: Latency measurements can vary from system to system. You may see different results on your system than shown\n",
|
||||
"in the pre-populated fields of this notebook."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"id": "e3ca9e40",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -34,13 +37,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"id": "efaac09c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if not (status := cutlass_api.utils.is_device_cc_supported({100, 103})):\n",
|
||||
" print(f\"This notebook requires a GPU with compute capability 100 or 103.\\n{status.error}\")\n",
|
||||
"if not (status := cutlass_api.utils.is_device_cc_supported({80, 89, 90, 100, 103})):\n",
|
||||
" print(\n",
|
||||
" f\"This notebook requires a GPU with compute capability >= 80.\\n{status.error}\"\n",
|
||||
" )\n",
|
||||
" import sys\n",
|
||||
"\n",
|
||||
" sys.exit(0)"
|
||||
@@ -61,7 +66,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"id": "b8c44947",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -76,19 +81,34 @@
|
||||
"# We use different operands in each iteration. Though not particularly relevant for\n",
|
||||
"# host latency, this is a best practice when benchmarking GPU kernels to avoid\n",
|
||||
"# unrealistic caching effects.\n",
|
||||
"As = [torch.randint(-1, 2, (M, K), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
|
||||
"Bs = [torch.randint(-1, 2, (K, N), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
|
||||
"outs = [torch.empty((M, N), device=\"cuda\", dtype=torch.float16) for _ in range(total_iterations)]\n",
|
||||
"As = [\n",
|
||||
" torch.randint(-1, 2, (M, K), device=\"cuda\", dtype=torch.float16)\n",
|
||||
" for _ in range(total_iterations)\n",
|
||||
"]\n",
|
||||
"Bs = [\n",
|
||||
" torch.randint(-1, 2, (K, N), device=\"cuda\", dtype=torch.float16)\n",
|
||||
" for _ in range(total_iterations)\n",
|
||||
"]\n",
|
||||
"outs = [\n",
|
||||
" torch.empty((M, N), device=\"cuda\", dtype=torch.float16)\n",
|
||||
" for _ in range(total_iterations)\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Construct arguments outside of the benchmarking loop. We will later also consider\n",
|
||||
"# cases in which they are constructed inside the benchmarking loop.\n",
|
||||
"args = [cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16) for i in range(total_iterations)]\n",
|
||||
"args = [\n",
|
||||
" cutlass_api.arguments.GemmArguments(\n",
|
||||
" A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float32\n",
|
||||
" )\n",
|
||||
" for i in range(total_iterations)\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"references = [(As[i] @ Bs[i]).to(outs[i].dtype) for i in range(total_iterations)]\n",
|
||||
"\n",
|
||||
"kernels = cutlass_api.get_kernels(args[0], cc=100)\n",
|
||||
"\n",
|
||||
"cc = cutlass_api.utils.device_cc()\n",
|
||||
"kernels = cutlass_api.get_kernels(args[0], cc=cc)\n",
|
||||
"assert len(kernels) > 0\n",
|
||||
"\n",
|
||||
"kernel = kernels[0]"
|
||||
]
|
||||
},
|
||||
@@ -102,14 +122,18 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": null,
|
||||
"id": "2472eafa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def benchmark(label, code, warmup_it=warmup_iterations, profiling_it=profiling_iterations):\n",
|
||||
"def benchmark(\n",
|
||||
" label, code, warmup_it=warmup_iterations, profiling_it=profiling_iterations\n",
|
||||
"):\n",
|
||||
" total_it = warmup_it + profiling_it\n",
|
||||
" assert total_it <= total_iterations, f\"Benchmark-local iteration count must be less than or equal to total iterations: {total_it} > {total_iterations}\"\n",
|
||||
" assert total_it <= total_iterations, (\n",
|
||||
" f\"Benchmark-local iteration count must be less than or equal to total iterations: {total_it} > {total_iterations}\"\n",
|
||||
" )\n",
|
||||
" # warmup\n",
|
||||
" rets = [None] * total_it\n",
|
||||
" for i in range(warmup_it):\n",
|
||||
@@ -151,7 +175,7 @@
|
||||
"### Compile once, run many times\n",
|
||||
"The `kernel.run` method takes in an optional `compiled_artifact` argument of type\n",
|
||||
"`cutlass_api.artifact.CompiledArtifact`. When this argument is set, the kernel\n",
|
||||
"will directly use the precompiled function within `compiled_argument`. When\n",
|
||||
"will directly use the precompiled function within `compiled_artifact`. When\n",
|
||||
"it is not set, the call to `kernel.run` will JIT compile the kernel on each\n",
|
||||
"invocation.\n",
|
||||
"\n",
|
||||
@@ -160,25 +184,29 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": null,
|
||||
"id": "6de11f56",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"stream = torch.cuda.current_stream()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def no_compiled_artifact(i: int):\n",
|
||||
" return kernel.run(args[i], stream=stream)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Compile the kernel once, reuse for each iterations\n",
|
||||
"compiled_artifact = kernel.compile(args[0])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def with_compiled_artifact(i: int):\n",
|
||||
" return kernel.run(args[i], stream=stream, compiled_artifact=compiled_artifact)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": null,
|
||||
"id": "350c9bd6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -192,8 +220,12 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"time_no_artifact, _ = benchmark(f\"Without compiled artifact\", no_compiled_artifact, warmup_it=2, profiling_it=5)\n",
|
||||
"time_w_artifact, _ = benchmark(f\"With compiled artifact\", with_compiled_artifact, warmup_it=2, profiling_it=5)"
|
||||
"time_no_artifact, _ = benchmark(\n",
|
||||
" f\"Without compiled artifact\", no_compiled_artifact, warmup_it=2, profiling_it=5\n",
|
||||
")\n",
|
||||
"time_w_artifact, _ = benchmark(\n",
|
||||
" f\"With compiled artifact\", with_compiled_artifact, warmup_it=2, profiling_it=5\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -215,21 +247,32 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": null,
|
||||
"id": "5b93dfae",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def with_supports_check(i: int):\n",
|
||||
" return kernel.run(args[i], compiled_artifact=compiled_artifact, stream=stream, assume_supported_args=False)\n",
|
||||
" return kernel.run(\n",
|
||||
" args[i],\n",
|
||||
" compiled_artifact=compiled_artifact,\n",
|
||||
" stream=stream,\n",
|
||||
" assume_supported_args=False,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def without_supports_check(i: int):\n",
|
||||
" return kernel.run(args[i], compiled_artifact=compiled_artifact, stream=stream, assume_supported_args=True)"
|
||||
" return kernel.run(\n",
|
||||
" args[i],\n",
|
||||
" compiled_artifact=compiled_artifact,\n",
|
||||
" stream=stream,\n",
|
||||
" assume_supported_args=True,\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": null,
|
||||
"id": "b282f437",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -244,7 +287,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"time_w_supports, _ = benchmark(\"With supports check\", with_supports_check)\n",
|
||||
"time_w_supports, _ = benchmark(\"With supports check\", with_supports_check)\n",
|
||||
"time_wo_supports, _ = benchmark(\"Bypass supports check\", without_supports_check)\n",
|
||||
"print(f\"Speedup with skip supports: {time_w_supports / time_wo_supports:.2f}x\")"
|
||||
]
|
||||
@@ -262,14 +305,16 @@
|
||||
"id": "656d5e2c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"CUTLASS API supports [CUDA Graphs](https://developer.nvidia.com/blog/cuda-graphs/) usage with PyTorch as usual.\n",
|
||||
"[CUDA Graphs](https://developer.nvidia.com/blog/cuda-graphs/) allow a sequence of GPU operations to be defined as a dependency graph and then launched as a single unit, significantly reducing CPU launch overhead and enabling whole-graph optimizations.\n",
|
||||
"\n",
|
||||
"CUTLASS API supports CUDA Graphs usage with PyTorch as usual.\n",
|
||||
"\n",
|
||||
"The kernel compilation must happen outside the CUDA graph. Then, we create a graph using usual PyTorch idioms to launch a kernel several times on the graph's stream."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": null,
|
||||
"id": "e614509f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -279,6 +324,10 @@
|
||||
"# Create a CUDA Graph to run our compiled kernel N times\n",
|
||||
"g = torch.cuda.CUDAGraph()\n",
|
||||
"with torch.cuda.graph(g):\n",
|
||||
"\n",
|
||||
" ### NOTE! Kernel compilation must happen outside the graph\n",
|
||||
" ### kernel.compile(args)\n",
|
||||
"\n",
|
||||
" # Run N iterations of our compiled kernel on the current stream\n",
|
||||
" for i in range(num_launches):\n",
|
||||
" kernel.run(\n",
|
||||
@@ -286,10 +335,7 @@
|
||||
" compiled_artifact=compiled_artifact,\n",
|
||||
" stream=torch.cuda.current_stream(),\n",
|
||||
" assume_supported_args=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"# Zero the output so we don't refcheck stale results\n",
|
||||
"_ = outs[0].zero_()"
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -297,8 +343,12 @@
|
||||
"id": "8fc69c6e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Once captured, we can replay the graph. This will only replay the kernel launches placed on the CUDA stream.\n",
|
||||
"Any other prepratory work on the host and arguments passed in from python are cached during the capture."
|
||||
"This records/captures all the kernel launches to the CUDA Stream associated with the graph `g`, without actually launching them.\n",
|
||||
"Once captured, we can replay the graph.\n",
|
||||
"\n",
|
||||
"Note that graph replay will only replay the kernel launches placed on the graph's stream\n",
|
||||
"* During graph capture, we must be careful to capture to the correct stream (`torch.cuda.current_stream()` under the graph context)\n",
|
||||
"* Any other preparatory work on the host and arguments passed in from Python are cached during the capture. Changing them would require re-capturing the graph"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -324,7 +374,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": null,
|
||||
"id": "45d4e739",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -348,12 +398,23 @@
|
||||
" assume_supported_args=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def with_cuda_graph(x: int):\n",
|
||||
" g.replay()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"time_wo_cuda_graph, _ = benchmark(f\"{num_launches} launches without CUDA Graph\", without_cuda_graph, warmup_it=0, profiling_it=1)\n",
|
||||
"time_w_cuda_graph, _ = benchmark(f\"{num_launches} launches with CUDA Graph\", with_cuda_graph, warmup_it=0, profiling_it=1)\n",
|
||||
"time_wo_cuda_graph, _ = benchmark(\n",
|
||||
" f\"{num_launches} launches without CUDA Graph\",\n",
|
||||
" without_cuda_graph,\n",
|
||||
" warmup_it=0,\n",
|
||||
" profiling_it=1,\n",
|
||||
")\n",
|
||||
"time_w_cuda_graph, _ = benchmark(\n",
|
||||
" f\"{num_launches} launches with CUDA Graph\",\n",
|
||||
" with_cuda_graph,\n",
|
||||
" warmup_it=0,\n",
|
||||
" profiling_it=1,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"Speedup with CUDA Graph: {time_wo_cuda_graph / time_w_cuda_graph:.2f}x\")"
|
||||
]
|
||||
@@ -371,8 +432,8 @@
|
||||
"id": "ee7f9fd2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"When applicable, CUTLASS API uses [Apache TVM FFI](https://tvm.apache.org/ffi/) under the hood for invoking compiled DSL kernels from Python.\n",
|
||||
"Apache TVM FFI is an open ABI and FFI for machine learning systems.\n",
|
||||
"[Apache TVM FFI](https://tvm.apache.org/ffi/) is an open ABI and FFI for machine learning systems.\n",
|
||||
"When available, CUTLASS API uses Apache TVM-FFI under the hood as its interface for invoking compiled DSL kernels from Python.\n",
|
||||
"\n",
|
||||
"TVM FFI is enabled by default in CUTLASS API, and is recommended for best performance."
|
||||
]
|
||||
@@ -413,7 +474,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": null,
|
||||
"id": "e8f56be3",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -428,10 +489,15 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"original_use_tvm_ffi = cutlass_api.config.GlobalOptions().use_tvm_ffi\n",
|
||||
"\n",
|
||||
"cutlass_api.config.GlobalOptions().use_tvm_ffi = True\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def run_iteration(i):\n",
|
||||
" args = cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16)\n",
|
||||
" args = cutlass_api.arguments.GemmArguments(\n",
|
||||
" A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16\n",
|
||||
" )\n",
|
||||
" return kernel.run(\n",
|
||||
" args,\n",
|
||||
" compiled_artifact=compiled_artifact,\n",
|
||||
@@ -439,18 +505,35 @@
|
||||
" assume_supported_args=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_arguments(i: int):\n",
|
||||
" return cutlass_api.arguments.GemmArguments(A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16)\n",
|
||||
" return cutlass_api.arguments.GemmArguments(\n",
|
||||
" A=As[i], B=Bs[i], out=outs[i], accumulator_type=torch.float16\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"args_creation_on, args = benchmark(\"[TVM-FFI ON ] Create args\", create_arguments)\n",
|
||||
"compilation_on, compiled = benchmark(\"[TVM-FFI ON ] Compile kernel\", lambda i: kernel.compile(args[i]), warmup_it=2, profiling_it=5)\n",
|
||||
"compilation_on, compiled = benchmark(\n",
|
||||
" \"[TVM-FFI ON ] Compile kernel\",\n",
|
||||
" lambda i: kernel.compile(args[i]),\n",
|
||||
" warmup_it=2,\n",
|
||||
" profiling_it=5,\n",
|
||||
")\n",
|
||||
"compiled_artifact = compiled[0]\n",
|
||||
"run_on, _ = benchmark(\"[TVM-FFI ON ] Run kernel\", lambda i: kernel.run(args[i], compiled_artifact=compiled_artifact, assume_supported_args=True, stream=stream))"
|
||||
"run_on, _ = benchmark(\n",
|
||||
" \"[TVM-FFI ON ] Run kernel\",\n",
|
||||
" lambda i: kernel.run(\n",
|
||||
" args[i],\n",
|
||||
" compiled_artifact=compiled_artifact,\n",
|
||||
" assume_supported_args=True,\n",
|
||||
" stream=stream,\n",
|
||||
" ),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": null,
|
||||
"id": "5a4c2db4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -467,9 +550,25 @@
|
||||
"source": [
|
||||
"cutlass_api.config.GlobalOptions().use_tvm_ffi = False\n",
|
||||
"args_creation_off, args = benchmark(\"[TVM-FFI OFF ] Create args\", create_arguments)\n",
|
||||
"compilation_off, compiled = benchmark(\"[TVM-FFI OFF ] Compile kernel\", lambda i: kernel.compile(args[i]), warmup_it=2, profiling_it=5)\n",
|
||||
"compilation_off, compiled = benchmark(\n",
|
||||
" \"[TVM-FFI OFF ] Compile kernel\",\n",
|
||||
" lambda i: kernel.compile(args[i]),\n",
|
||||
" warmup_it=2,\n",
|
||||
" profiling_it=5,\n",
|
||||
")\n",
|
||||
"compiled_artifact = compiled[0]\n",
|
||||
"run_off, _ = benchmark(\"[TVM-FFI OFF ] Run kernel\", lambda i: kernel.run(args[i], compiled_artifact=compiled_artifact, assume_supported_args=True, stream=stream))"
|
||||
"run_off, _ = benchmark(\n",
|
||||
" \"[TVM-FFI OFF ] Run kernel\",\n",
|
||||
" lambda i: kernel.run(\n",
|
||||
" args[i],\n",
|
||||
" compiled_artifact=compiled_artifact,\n",
|
||||
" assume_supported_args=True,\n",
|
||||
" stream=stream,\n",
|
||||
" ),\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Restore original setting\n",
|
||||
"cutlass_api.config.GlobalOptions().use_tvm_ffi = original_use_tvm_ffi"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
166
python/cutlass_api/examples/004_fake_tensors.ipynb
Normal file
166
python/cutlass_api/examples/004_fake_tensors.ipynb
Normal file
@@ -0,0 +1,166 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4620d513",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Using fake tensors with the CUTLASS API\n",
|
||||
"Fake tensors (e.g., [torch's FakeTensor](https://docs.pytorch.org/docs/2.8/torch.compiler_fake_tensor.html))\n",
|
||||
"are useful for describing the properties of a tensor without actually allocating backing data.\n",
|
||||
"\n",
|
||||
"This example shows how fake tensors can be used within the CUTLASS API\n",
|
||||
"for discovering and compiling a GEMM kernel."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "d231b32e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"import cutlass_api\n",
|
||||
"\n",
|
||||
"torch.manual_seed(2025)\n",
|
||||
"\n",
|
||||
"if not (status := cutlass_api.utils.is_device_cc_supported({80, 89, 90, 100, 103})):\n",
|
||||
" print(f\"This notebook requires a GPU with compute capability >= 80.\\n{status.error}\")\n",
|
||||
" import sys\n",
|
||||
" sys.exit(0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f7af2d90",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We first set up operands `A`, `B`, and `out` in torch's `FakeTensorMode`.\n",
|
||||
"These will have all the properties needed for CUTLASS API to construct\n",
|
||||
"the internal representations of tensors used for discovering and compiling\n",
|
||||
"kernels."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "9426b66f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"FakeTensor(..., device='cuda:0', size=(128, 512), dtype=torch.float16)\n",
|
||||
"FakeTensor(..., device='cuda:0', size=(512, 256), dtype=torch.float16)\n",
|
||||
"FakeTensor(..., device='cuda:0', size=(128, 256), dtype=torch.float16)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"M, N, K = 128, 256, 512\n",
|
||||
"\n",
|
||||
"with torch._subclasses.fake_tensor.FakeTensorMode():\n",
|
||||
" A_fake = torch.randn(M, K, device=\"cuda\", dtype=torch.float16)\n",
|
||||
" B_fake = torch.randn(K, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
" out_fake = torch.empty(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"\n",
|
||||
"print(A_fake)\n",
|
||||
"print(B_fake)\n",
|
||||
"print(out_fake)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4f540c78",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can now use these fake tensors to create `GemmArguments`, and use\n",
|
||||
"these to discover and compile a compatible kernel. Note that the same APIs are\n",
|
||||
"used in creating `GemmArguments` as would be used if using\n",
|
||||
"\"real\" tensors."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "e32b700d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"args_fake = cutlass_api.arguments.GemmArguments(\n",
|
||||
" A_fake, B_fake, out_fake, accumulator_type=torch.float32)\n",
|
||||
"\n",
|
||||
"cc = cutlass_api.utils.device_cc()\n",
|
||||
"kernels = cutlass_api.get_kernels(args_fake, cc=cc)\n",
|
||||
"assert len(kernels) > 0\n",
|
||||
"\n",
|
||||
"kernel = kernels[0]\n",
|
||||
"compiled_artifact = kernel.compile(args_fake)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "07fff511",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The `kernel` and `compiled_artifact` discovered using fake tensors\n",
|
||||
"above can now used for running the kernel using real tensors."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "b3034bf1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create real tensors\n",
|
||||
"A_real = torch.randn(M, K, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"B_real = torch.randn(K, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"out_real = torch.empty(M, N, device=\"cuda\", dtype=torch.float16)\n",
|
||||
"\n",
|
||||
"args_real = cutlass_api.arguments.GemmArguments(\n",
|
||||
" A_real, B_real, out_real, accumulator_type=torch.float32)\n",
|
||||
"\n",
|
||||
"# Run the kernel using the compiled_artifact from resulting\n",
|
||||
"# from compiling with fake tensors.\n",
|
||||
"kernel.run(args_real, compiled_artifact)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "09871eca",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ref = A_real @ B_real\n",
|
||||
"torch.testing.assert_close(out_real, ref)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -22,6 +22,8 @@ torch = [
|
||||
]
|
||||
test = [
|
||||
"jupyter",
|
||||
"nbconvert",
|
||||
"nbformat",
|
||||
"pytest",
|
||||
"cutlass_api[torch]",
|
||||
]
|
||||
|
||||
72
python/cutlass_api/test/conftest.py
Normal file
72
python/cutlass_api/test/conftest.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
import cutlass_api
|
||||
from cutlass_api.config import GlobalOptions
|
||||
|
||||
#
|
||||
# Before each test, save the GlobalOptions dict
|
||||
# After each test, restore the GlobalOptions dict
|
||||
#
|
||||
|
||||
global_options = {}
|
||||
|
||||
|
||||
def setup_function():
|
||||
global global_options
|
||||
GlobalOptions().save(global_options)
|
||||
|
||||
|
||||
def teardown_function():
|
||||
global global_options
|
||||
GlobalOptions().restore(global_options)
|
||||
|
||||
|
||||
#
|
||||
# Fixtures for toggling the TVM FFI global option and forcing its enablement or disablement
|
||||
#
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=[True, False], ids=["use_tvm_ffi=True", f"use_tvm_ffi=False"], autouse=False
|
||||
)
|
||||
def fixture_toggle_tvm_ffi(request):
|
||||
GlobalOptions().use_tvm_ffi = request.param
|
||||
|
||||
|
||||
@pytest.fixture(autouse=False)
|
||||
def fixture_enable_tvm_ffi(request):
|
||||
GlobalOptions().use_tvm_ffi = True
|
||||
|
||||
|
||||
@pytest.fixture(autouse=False)
|
||||
def fixture_disable_tvm_ffi(request):
|
||||
GlobalOptions().use_tvm_ffi = False
|
||||
@@ -68,6 +68,7 @@ def test_gemm_sm100(
|
||||
c_dtype: torch.dtype,
|
||||
accumulator_type: torch.dtype,
|
||||
n_iterations: int,
|
||||
fixture_toggle_tvm_ffi,
|
||||
):
|
||||
A = torch.randint(-1, 2, (M, K), device="cuda").to(ab_dtype)
|
||||
B = torch.randint(-1, 2, (K, N), device="cuda").to(ab_dtype)
|
||||
|
||||
@@ -32,6 +32,7 @@ import torch
|
||||
|
||||
import cutlass_api
|
||||
from cutlass_api.utils import device_cc
|
||||
from cutlass_api.config import GlobalOptions
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -48,11 +49,7 @@ from cutlass_api.utils import device_cc
|
||||
torch.float16,
|
||||
],
|
||||
)
|
||||
def test_elementwise_add(
|
||||
M: int,
|
||||
N: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
def test_elementwise_add(M: int, N: int, dtype: torch.dtype, fixture_toggle_tvm_ffi):
|
||||
A = torch.randint(-1, 2, (M, N), device="cuda", dtype=dtype)
|
||||
B = torch.randint(-1, 2, (M, N), device="cuda", dtype=dtype)
|
||||
D = torch.empty((M, N), device="cuda", dtype=dtype)
|
||||
|
||||
@@ -38,7 +38,6 @@ import torch
|
||||
import cutlass
|
||||
|
||||
import cutlass_api
|
||||
from cutlass_api.config import GlobalOptions
|
||||
from cutlass_api.utils import is_device_cc_supported
|
||||
|
||||
torch.manual_seed(2025)
|
||||
@@ -62,8 +61,17 @@ logger = logging.getLogger(__name__)
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"use_tvm_ffi",
|
||||
[True, False],
|
||||
"a_major, b_major, c_major",
|
||||
[
|
||||
("k", "k", "n"),
|
||||
("k", "k", "m"),
|
||||
("k", "n", "m"),
|
||||
("k", "n", "n"),
|
||||
("m", "k", "n"),
|
||||
("m", "k", "m"),
|
||||
("m", "n", "m"),
|
||||
("m", "n", "n"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not is_device_cc_supported({100})
|
||||
@@ -71,20 +79,29 @@ logger = logging.getLogger(__name__)
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_sm100(
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
L: int,
|
||||
M: int, N: int, K: int, L: int,
|
||||
ab_dtype: torch.dtype,
|
||||
c_dtype: torch.dtype,
|
||||
accumulator_type: torch.dtype,
|
||||
use_tvm_ffi: bool,
|
||||
a_major: str,
|
||||
b_major: str,
|
||||
c_major: str,
|
||||
fixture_toggle_tvm_ffi,
|
||||
):
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
|
||||
if a_major == "k":
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
else:
|
||||
A = torch.randint(-1, 2, (L, K, M), device="cuda", dtype=ab_dtype).permute(0, 2, 1)
|
||||
|
||||
GlobalOptions().use_tvm_ffi = use_tvm_ffi
|
||||
if b_major == "n":
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
else:
|
||||
B = torch.randint(-1, 2, (L, N, K), device="cuda", dtype=ab_dtype).permute(0, 2, 1)
|
||||
|
||||
if c_major == "n":
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
|
||||
else:
|
||||
D = torch.empty((L, N, M), device="cuda", dtype=c_dtype).permute(0, 2, 1)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type)
|
||||
|
||||
@@ -100,16 +117,12 @@ def test_gemm_sm100(
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"use_tvm_ffi",
|
||||
[True, False],
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not is_device_cc_supported({100})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_sm100_2d(use_tvm_ffi: bool):
|
||||
def test_gemm_sm100_2d(fixture_toggle_tvm_ffi):
|
||||
ab_dtype = torch.float16
|
||||
c_dtype = torch.float16
|
||||
accumulator_type = torch.float32
|
||||
@@ -120,8 +133,6 @@ def test_gemm_sm100_2d(use_tvm_ffi: bool):
|
||||
B = torch.randint(-1, 2, (K, N), device="cuda", dtype=ab_dtype)
|
||||
D = torch.empty((M, N), device="cuda", dtype=c_dtype)
|
||||
|
||||
GlobalOptions().use_tvm_ffi = use_tvm_ffi
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type)
|
||||
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
@@ -141,7 +152,7 @@ def test_gemm_sm100_2d(use_tvm_ffi: bool):
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a",
|
||||
)
|
||||
def test_gemm_sm100_int8():
|
||||
def test_gemm_sm100_int8(fixture_toggle_tvm_ffi):
|
||||
M, N, K = 256, 512, 128
|
||||
A = torch.randint(-1, 2, (M, K), device="cuda", dtype=torch.int8)
|
||||
B = torch.randint(-1, 2, (K, N), device="cuda", dtype=torch.int8)
|
||||
@@ -159,27 +170,48 @@ def test_gemm_sm100_int8():
|
||||
reference = torch._int_mm(A, B)
|
||||
torch.testing.assert_close(D, reference.to(D.dtype))
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
find_spec("tvm_ffi") is None,
|
||||
reason="FP8 currently requires TVM FFI to be installed",
|
||||
@pytest.mark.parametrize(
|
||||
"problem_size",
|
||||
[
|
||||
(256, 512, 128),
|
||||
(256, 512, 128, 1),
|
||||
(256, 512, 128, 2),
|
||||
]
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not is_device_cc_supported({100})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a",
|
||||
)
|
||||
def test_gemm_sm100_fp8():
|
||||
# Currently, FP8 is only supported via TVM FFI.
|
||||
GlobalOptions().use_tvm_ffi = True
|
||||
def test_gemm_sm100_fp8(
|
||||
problem_size: tuple[int, ...],
|
||||
fixture_enable_tvm_ffi, # FP8 currently requires TVM FFI to be installed
|
||||
):
|
||||
M, N, K = problem_size[:3]
|
||||
|
||||
M, N, K = 256, 512, 128
|
||||
# Create torch fp8 tensors for A and B
|
||||
A = torch.randint(-1, 2, (M, K), device="cuda").to(torch.float8_e4m3fn)
|
||||
D = torch.empty((M, N), device="cuda", dtype=torch.float32)
|
||||
identity_scale = torch.ones(1, device="cuda", dtype=torch.float32)
|
||||
|
||||
# Transpose B because torch._scaled_mm expects B to be column-major
|
||||
B = torch.randint(-1, 2, (N, K), device="cuda").to(torch.float8_e4m3fn).T
|
||||
if len(problem_size) == 3:
|
||||
A = torch.randint(-1, 2, (M, K), device="cuda").to(torch.float8_e4m3fn)
|
||||
D = torch.empty((M, N), device="cuda", dtype=torch.float32)
|
||||
|
||||
# Transpose B because torch._scaled_mm expects B to be column-major
|
||||
B = torch.randint(-1, 2, (N, K), device="cuda").to(torch.float8_e4m3fn).T
|
||||
|
||||
reference = torch._scaled_mm(A, B, identity_scale, identity_scale, out_dtype=torch.float32)
|
||||
else:
|
||||
L = problem_size[3]
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda").to(torch.float8_e4m3fn)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=torch.float32)
|
||||
|
||||
# Transpose B because torch._scaled_mm expects B to be column-major
|
||||
B = torch.randint(-1, 2, (L, N, K), device="cuda").to(torch.float8_e4m3fn).permute(0, 2, 1)
|
||||
|
||||
reference = torch.empty((L, M, N), device="cuda", dtype=torch.float32)
|
||||
for l in range(L):
|
||||
reference[l, :, :] = torch._scaled_mm(
|
||||
A[l, :, :], B[l, :, :], identity_scale, identity_scale, out_dtype=torch.float32
|
||||
)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type=torch.float32)
|
||||
kernels = cutlass_api.get_kernels(args, cc=100)
|
||||
@@ -190,10 +222,6 @@ def test_gemm_sm100_fp8():
|
||||
compiled_artifact = kernel.compile(args)
|
||||
kernel.run(args, compiled_artifact=compiled_artifact, assume_supported_args=True)
|
||||
|
||||
identity_scale = torch.ones(1, device="cuda", dtype=torch.float32)
|
||||
reference = torch._scaled_mm(
|
||||
A, B, identity_scale, identity_scale, out_dtype=torch.float32
|
||||
)
|
||||
torch.testing.assert_close(D, reference)
|
||||
|
||||
|
||||
@@ -207,7 +235,7 @@ def test_no_gemms_available():
|
||||
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type=torch.float32)
|
||||
kernels = cutlass_api.get_kernels(args, cc=70)
|
||||
|
||||
# There are currenlty no kernels available for compute capability 70.
|
||||
# There are currently no kernels available for compute capability 70.
|
||||
assert len(kernels) == 0
|
||||
|
||||
|
||||
|
||||
@@ -33,8 +33,8 @@ import pytest
|
||||
import torch
|
||||
|
||||
import cutlass_api
|
||||
from cutlass_api.utils import device_cc
|
||||
from cutlass_api.config import GlobalOptions
|
||||
from cutlass_api.utils import is_device_cc_supported
|
||||
|
||||
torch.manual_seed(2025)
|
||||
|
||||
@@ -62,7 +62,7 @@ def base_data_types():
|
||||
|
||||
|
||||
def supports_sm100af():
|
||||
return device_cc() == 100 and (
|
||||
return is_device_cc_supported({100}) and (
|
||||
os.getenv("CUTE_DSL_ARCH", "") in ["", "sm_100a", "sm_100f"]
|
||||
)
|
||||
|
||||
@@ -485,9 +485,7 @@ def test_gemm_fusion_return_acc():
|
||||
|
||||
epi_str = "def epi(accum, C): D = relu(accum) + C; return D, accum"
|
||||
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(
|
||||
epi_str, C=C, D=D, accum=accum
|
||||
)
|
||||
epi_args = cutlass_api.arguments.EpilogueArguments(epi_str, C=C, D=D, accum=accum)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(
|
||||
A=A, B=B, out=D, accumulator_type=accumulator_type, epilogue=epi_args
|
||||
@@ -628,19 +626,13 @@ def test_gemm_fusion_matmul_input_as_aux():
|
||||
@pytest.mark.parametrize(
|
||||
"ab_dtype, c_dtype, d_dtype, accumulator_type", base_data_types()
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"use_tvm_ffi",
|
||||
[True, False],
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not supports_sm100af(),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_alpha_beta(
|
||||
M, N, K, L, ab_dtype, c_dtype, d_dtype, accumulator_type, use_tvm_ffi
|
||||
M, N, K, L, ab_dtype, c_dtype, d_dtype, accumulator_type, fixture_toggle_tvm_ffi
|
||||
):
|
||||
GlobalOptions().use_tvm_ffi = use_tvm_ffi
|
||||
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
C = torch.randint(-1, 2, (L, M, N), device="cuda", dtype=c_dtype)
|
||||
@@ -680,7 +672,7 @@ def test_gemm_alpha_beta(
|
||||
not supports_sm100af(),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_big_epi():
|
||||
def test_gemm_big_epi(fixture_toggle_tvm_ffi):
|
||||
M, N, K, L = 256, 512, 128, 2
|
||||
ab_dtype = torch.float16
|
||||
c_dtype = torch.float32
|
||||
@@ -809,7 +801,7 @@ def test_gemm_big_epi():
|
||||
not supports_sm100af(),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_gemm_fusion_not_available():
|
||||
def test_gemm_fusion_not_available(fixture_toggle_tvm_ffi):
|
||||
M = 256
|
||||
N = 512
|
||||
K = 1024
|
||||
|
||||
211
python/cutlass_api/test/integration/test_gemm_sm80.py
Normal file
211
python/cutlass_api/test/integration/test_gemm_sm80.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
import logging
|
||||
import os
|
||||
from importlib.util import find_spec
|
||||
from pprint import pformat
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import cutlass
|
||||
|
||||
import cutlass_api
|
||||
from cutlass_api.utils import is_device_cc_supported
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Set to None to test all kernels, otherwise set to the number of kernels to test
|
||||
_MAX_KERNELS_TO_TEST = 1
|
||||
|
||||
# ab_dtype, c_dtype, accumulator_type
|
||||
# NOTE: The current Sm80TensorOpGemmKernel implementation only generates kernels
|
||||
# for Float32 accumulators (see _metadata_operand_combinations in sm80_tensorop_gemm.py),
|
||||
# so we restrict tests to those combinations for now.
|
||||
_DTYPES = [
|
||||
(torch.float16, torch.float16, torch.float32),
|
||||
(torch.bfloat16, torch.bfloat16, torch.float32),
|
||||
]
|
||||
|
||||
# M, N, K, L and L is optional
|
||||
def get_sizes_mnkl(level: int = 0) -> list[tuple[int, int, int, int]]:
|
||||
problem_sizes = []
|
||||
|
||||
if level == 0:
|
||||
problem_sizes.append((128, 128, 64, 1))
|
||||
problem_sizes.append((256, 256, 128, 1))
|
||||
problem_sizes.append((512, 256, 256, 1))
|
||||
problem_sizes.append((256, 512, 384, 2))
|
||||
else:
|
||||
raise ValueError(f"Invalid level: {level}")
|
||||
|
||||
return problem_sizes
|
||||
|
||||
# Layout combinations to test: (A_layout, B_layout, C_layout)
|
||||
# 't' = transposed/row-major (last dim contiguous, stride=1)
|
||||
# 'n' = normal/column-major (middle dim contiguous, stride=1)
|
||||
def get_layouts() -> list[tuple[str, str, str]]:
|
||||
return [
|
||||
("t", "t", "t"), # All row-major
|
||||
("t", "n", "t"), # A row-major, B column-major, C row-major
|
||||
("n", "t", "t"), # A column-major, B row-major, C row-major
|
||||
("n", "n", "t"), # A,B column-major, C row-major
|
||||
("t", "t", "n"),
|
||||
("t", "n", "n"),
|
||||
("n", "t", "n"),
|
||||
("n", "n", "n"),
|
||||
]
|
||||
|
||||
def create_layout_tensor(shape: tuple, layout: str, dtype: torch.dtype, device: str = "cuda") -> torch.Tensor:
|
||||
"""Create a tensor with specified layout ensuring proper alignment.
|
||||
|
||||
For layout 't' (row-major): last dimension is contiguous
|
||||
For layout 'n' (column-major): middle dimension is contiguous
|
||||
"""
|
||||
L, dim1, dim2 = shape
|
||||
|
||||
if layout == "t":
|
||||
# Row-major: last dimension contiguous
|
||||
return torch.randint(-1, 2, (L, dim1, dim2), device=device, dtype=dtype)
|
||||
else:
|
||||
# Column-major: middle dimension contiguous
|
||||
tensor_temp = torch.randint(-1, 2, (L, dim2, dim1), device=device, dtype=dtype)
|
||||
return tensor_temp.transpose(1, 2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M, N, K, L, ab_dtype, c_dtype, accumulator_type",
|
||||
[(*sizes, *dtypes) for sizes in get_sizes_mnkl() for dtypes in _DTYPES],
|
||||
)
|
||||
def test_gemm_sm80(
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
L: int,
|
||||
ab_dtype: torch.dtype,
|
||||
c_dtype: torch.dtype,
|
||||
accumulator_type: torch.dtype,
|
||||
fixture_toggle_tvm_ffi,
|
||||
):
|
||||
A = torch.randint(-1, 2, (L, M, K), device="cuda", dtype=ab_dtype)
|
||||
B = torch.randint(-1, 2, (L, K, N), device="cuda", dtype=ab_dtype)
|
||||
D = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type)
|
||||
kernels = cutlass_api.get_kernels(args, cc=80)
|
||||
|
||||
assert len(kernels) > 0, "No kernels returned for the given configuration"
|
||||
|
||||
max_kernels = len(kernels) if _MAX_KERNELS_TO_TEST is None else _MAX_KERNELS_TO_TEST
|
||||
kernels_to_test = kernels[:max_kernels]
|
||||
|
||||
# Compute reference using PyTorch's GPU matmul
|
||||
reference = (A.to(torch.float32) @ B.to(torch.float32)).to(c_dtype)
|
||||
|
||||
# Test the selected kernels
|
||||
for idx, kernel in enumerate(kernels_to_test):
|
||||
logger.debug(f"Testing kernel {idx+1}/{len(kernels_to_test)}: {kernel.metadata.kernel_name}")
|
||||
logger.debug(f"Kernel metadata:\n{pformat(kernel.metadata)}")
|
||||
|
||||
# Create fresh output tensor for each kernel to avoid interference
|
||||
D_test = torch.empty((L, M, N), device="cuda", dtype=c_dtype)
|
||||
args_test = cutlass_api.arguments.GemmArguments(A, B, D_test, accumulator_type)
|
||||
|
||||
try:
|
||||
kernel.run(args_test)
|
||||
# Verify correctness against PyTorch reference
|
||||
torch.testing.assert_close(D_test, reference, atol=1e-2, rtol=1e-2)
|
||||
logger.debug(f"Kernel {idx+1}/{len(kernels_to_test)} passed")
|
||||
except Exception as e:
|
||||
pytest.fail(
|
||||
f"Kernel {idx+1}/{len(kernels_to_test)} ({kernel.metadata.kernel_name}) failed: {e}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M, N, K, L, ab_dtype, c_dtype, accumulator_type, layouts",
|
||||
[
|
||||
(*sizes, *dtypes, layouts)
|
||||
for sizes in [(248, 264, 248, 1)] # Test with one size to keep test time reasonable
|
||||
for dtypes in _DTYPES
|
||||
for layouts in get_layouts()
|
||||
],
|
||||
)
|
||||
def test_gemm_sm80_layouts(
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
L: int,
|
||||
ab_dtype: torch.dtype,
|
||||
c_dtype: torch.dtype,
|
||||
accumulator_type: torch.dtype,
|
||||
layouts: tuple[str, str, str],
|
||||
fixture_disable_tvm_ffi,
|
||||
):
|
||||
"""Test different tensor layouts (row-major vs column-major).
|
||||
|
||||
Note: Currently only tests with use_tvm_ffi=False because TVM-FFI has strict
|
||||
stride alignment requirements that are violated by transposed column-major layouts.
|
||||
"""
|
||||
layout_A, layout_B, layout_C = layouts
|
||||
|
||||
# Create tensors with specified layouts
|
||||
A = create_layout_tensor((L, M, K), layout_A, ab_dtype)
|
||||
B = create_layout_tensor((L, K, N), layout_B, ab_dtype)
|
||||
D = create_layout_tensor((L, M, N), layout_C, c_dtype)
|
||||
|
||||
args = cutlass_api.arguments.GemmArguments(A, B, D, accumulator_type)
|
||||
kernels = cutlass_api.get_kernels(args, cc=80)
|
||||
|
||||
assert len(kernels) > 0, f"No kernels returned for layout {layout_A}{layout_B}{layout_C}"
|
||||
|
||||
max_kernels = len(kernels) if _MAX_KERNELS_TO_TEST is None else _MAX_KERNELS_TO_TEST
|
||||
kernels_to_test = kernels[:max_kernels]
|
||||
|
||||
# Compute reference using PyTorch's GPU matmul
|
||||
reference = (A.to(torch.float32) @ B.to(torch.float32)).to(c_dtype)
|
||||
|
||||
# Test the selected kernels
|
||||
for idx, kernel in enumerate(kernels_to_test):
|
||||
logger.debug(f"Testing kernel {idx+1}/{len(kernels_to_test)}: {kernel.metadata.kernel_name}")
|
||||
|
||||
# Create fresh output tensor for each kernel with same layout
|
||||
D_test = create_layout_tensor((L, M, N), layout_C, c_dtype)
|
||||
args_test = cutlass_api.arguments.GemmArguments(A, B, D_test, accumulator_type)
|
||||
|
||||
try:
|
||||
kernel.run(args_test)
|
||||
# Verify correctness against PyTorch reference
|
||||
torch.testing.assert_close(D_test, reference, atol=1e-2, rtol=1e-2)
|
||||
logger.debug(f"Kernel {idx+1}/{len(kernels_to_test)} passed for layout {layout_A}{layout_B}{layout_C}")
|
||||
except Exception as e:
|
||||
pytest.fail(
|
||||
f"Kernel {idx+1}/{len(kernels_to_test)} ({kernel.metadata.kernel_name}) failed for layout {layout_A}{layout_B}{layout_C}: {e}"
|
||||
)
|
||||
@@ -27,6 +27,7 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -36,10 +37,11 @@ import cutlass_api
|
||||
@pytest.mark.parametrize(
|
||||
"notebook_name, supported_ccs",
|
||||
[
|
||||
("000_gemm.ipynb", [100, 103]),
|
||||
("000_gemm.ipynb", [80, 89, 90,100, 103]),
|
||||
("001_gemm_with_fused_epilogue.ipynb", [100, 103]),
|
||||
("002_bring_your_own_kernel.ipynb", [80, 89, 90, 100, 103, 120, 121]),
|
||||
("003_host_latency_best_practices.ipynb", [100, 103]),
|
||||
("003_host_latency_best_practices.ipynb", [80, 89, 90, 100, 103]),
|
||||
("004_fake_tensors.ipynb", [80, 89, 90, 100, 103]),
|
||||
],
|
||||
)
|
||||
def test_notebooks(notebook_name, supported_ccs):
|
||||
@@ -74,12 +76,21 @@ def test_notebooks(notebook_name, supported_ccs):
|
||||
notebook_dir = os.path.join(os.path.dirname(__file__), "..", "..", "examples")
|
||||
full_notebook_path = os.path.join(notebook_dir, notebook_name)
|
||||
|
||||
import nbconvert
|
||||
import subprocess
|
||||
import sys
|
||||
from nbconvert.preprocessors import ExecutePreprocessor
|
||||
import nbformat
|
||||
|
||||
with open(full_notebook_path, "r") as file:
|
||||
# Register the current Python interpreter as the python3 kernel
|
||||
subprocess.run(
|
||||
[sys.executable, "-m", "ipykernel", "install", "--user", "--name", "python3"],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
|
||||
with Path(full_notebook_path).open() as file:
|
||||
notebook = nbformat.read(file, as_version=4)
|
||||
ep = nbconvert.preprocessors.ExecutePreprocessor(timeout=600, kernel_name="python3")
|
||||
ep = ExecutePreprocessor(timeout=600, kernel_name="python3")
|
||||
|
||||
# Execute the notebook. This call will error out on any assertions or errors in the
|
||||
# notebook itself. Allow these to propagate up so the test will fail on notebook failure.
|
||||
|
||||
@@ -32,7 +32,6 @@ import torch
|
||||
|
||||
import cutlass_api
|
||||
|
||||
|
||||
pytestmark = pytest.mark.arch("80")
|
||||
|
||||
|
||||
@@ -69,7 +68,7 @@ def epi(accum, C, alpha, beta):
|
||||
except ValueError as e:
|
||||
assert "F" in str(e)
|
||||
else:
|
||||
assert False, "Failed to catch missing keyword"
|
||||
raise AssertionError("Failed to catch missing keyword")
|
||||
|
||||
|
||||
def test_extra_keywords():
|
||||
@@ -92,4 +91,4 @@ def epi(accum, C, alpha, beta):
|
||||
except ValueError as e:
|
||||
assert "gamma" in str(e)
|
||||
else:
|
||||
assert False, "Failed to catch extra keyword"
|
||||
raise AssertionError("Failed to catch extra keyword")
|
||||
|
||||
151
python/cutlass_api/test/unit/test_metadata.py
Normal file
151
python/cutlass_api/test/unit/test_metadata.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import cutlass
|
||||
from cutlass import cute
|
||||
|
||||
import cutlass_api
|
||||
from cutlass_api.arguments import ElementwiseArguments
|
||||
from cutlass_api.config import GlobalOptions
|
||||
from cutlass_api.metadata import (
|
||||
ElementwiseOperandsMetadata,
|
||||
TensorAttributes,
|
||||
)
|
||||
|
||||
|
||||
class NoopKernelForTesting(cutlass_api.providers.cutedsl.kernel.CuteDslKernel):
|
||||
@cute.jit
|
||||
def impl(self, A, B, out, stream):
|
||||
cute.printf("Called kernel from host successfully!")
|
||||
return
|
||||
|
||||
def compile(self, args: ElementwiseArguments):
|
||||
stream = cute.runtime.make_fake_stream()
|
||||
return self.cute_compile(self.impl, args.A, args.B, args.out, stream)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
args: ElementwiseArguments,
|
||||
compiled_artifact,
|
||||
stream,
|
||||
workspace=None,
|
||||
):
|
||||
self.cute_run(compiled_artifact, args.A, args.B, args.out, stream)
|
||||
|
||||
def generate_kernels(_ignored_filter, _ignored_epilogue_args, _ignored_cc):
|
||||
attrs = TensorAttributes(
|
||||
stride=(0, 1),
|
||||
dtype=cutlass.Float16,
|
||||
divisibility=8,
|
||||
)
|
||||
metadata = cutlass_api.KernelMetadata(
|
||||
kernel_name="NoopKernelForTesting",
|
||||
kernel_class=NoopKernelForTesting,
|
||||
min_cc=80,
|
||||
operands=ElementwiseOperandsMetadata(
|
||||
A=attrs,
|
||||
B=attrs,
|
||||
out=attrs,
|
||||
),
|
||||
)
|
||||
return [NoopKernelForTesting(metadata)]
|
||||
|
||||
|
||||
kernel = NoopKernelForTesting.generate_kernels(None, None, None)[0]
|
||||
|
||||
|
||||
def test_perfectly_aligned():
|
||||
divisibility = kernel.metadata.operands.A.divisibility
|
||||
A, B, out = [
|
||||
torch.randn(divisibility, divisibility * 2, dtype=torch.float16, device="cuda")
|
||||
for _ in range(3)
|
||||
]
|
||||
args = ElementwiseArguments(A=A, B=B, out=out)
|
||||
kernel.run(args)
|
||||
|
||||
|
||||
def test_overaligned():
|
||||
A, B, out = [
|
||||
torch.randn(1024, 1024, dtype=torch.float16, device="cuda") for _ in range(3)
|
||||
]
|
||||
args = ElementwiseArguments(A=A, B=B, out=out)
|
||||
kernel.run(args)
|
||||
|
||||
|
||||
def _check_misaligned_args(use_tvm_ffi: bool, error_match_string: str, **tensors):
|
||||
"""Helper to test various misalignment errors are properly caught.
|
||||
|
||||
With TVM-FFI:
|
||||
args creation may succeed, but kernel.supports must fail.
|
||||
TVM-FFI should still catch errors if user bypasses supports.
|
||||
Without TVM-FFI:
|
||||
error must be caught early during argument creation itself.
|
||||
"""
|
||||
GlobalOptions().use_tvm_ffi = use_tvm_ffi
|
||||
|
||||
if use_tvm_ffi:
|
||||
args = ElementwiseArguments(**tensors)
|
||||
assert not kernel.supports(args), "Unsupported args should be rejected"
|
||||
|
||||
with pytest.raises(Exception, match=error_match_string):
|
||||
kernel.run(args, assume_supported_args=True)
|
||||
else:
|
||||
with pytest.raises(Exception, match=error_match_string):
|
||||
ElementwiseArguments(**tensors)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tvm_ffi", [True, False])
|
||||
def test_underaligned(use_tvm_ffi: bool):
|
||||
divisibility = kernel.metadata.operands.A.divisibility
|
||||
A, B, out = [
|
||||
torch.randn(
|
||||
divisibility + divisibility // 2,
|
||||
divisibility + divisibility // 4,
|
||||
dtype=torch.float16,
|
||||
device="cuda",
|
||||
)
|
||||
for _ in range(3)
|
||||
]
|
||||
_check_misaligned_args(use_tvm_ffi, "divisible", A=A, B=B, out=out)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tvm_ffi", [True, False])
|
||||
def test_ptr_misaligned(use_tvm_ffi: bool):
|
||||
rows = kernel.metadata.operands.A.divisibility * 4
|
||||
cols = rows
|
||||
offset = 117
|
||||
|
||||
A = torch.randn(rows * cols + offset, dtype=torch.float16, device="cuda")
|
||||
B = torch.randn(rows, cols, dtype=torch.float16, device="cuda")
|
||||
out = torch.randn(rows, cols, dtype=torch.float16, device="cuda")
|
||||
A_offset = torch.as_strided(A[offset:], (rows, cols), (cols, 1))
|
||||
|
||||
_check_misaligned_args(use_tvm_ffi, "align", A=A_offset, B=B, out=out)
|
||||
Reference in New Issue
Block a user