mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
2026-01-12 updates
This commit is contained in:
@@ -69,8 +69,8 @@ See CUTLASS's [Compatibility section](https://github.com/NVIDIA/cutlass?tab=read
|
||||
### Current support
|
||||
|
||||
* Dense GEMM: `out = A @ B`
|
||||
- Compute capabilities: 100, 103 (WIP to expand to more)
|
||||
- Input precisions (A and B must be of same type): F16, BF16, TF32, INT8
|
||||
- Compute capabilities: 100, 103
|
||||
- Input precisions (A and B must be of same type): F16, BF16, INT8
|
||||
- Output precisions: F32, F16, BF16, INT32
|
||||
- Epilogue operations:
|
||||
- Auxiliary load of tensors equal in rank and shape of `out`
|
||||
@@ -79,10 +79,20 @@ See CUTLASS's [Compatibility section](https://github.com/NVIDIA/cutlass?tab=read
|
||||
- Tensor-tensor elementwise or tensor-scalar addition, multiplication, subtraction, division
|
||||
- Elementwise tensor exponent, relu, sigmoid, tanh
|
||||
- Note: Partial support exists on CC 80/89/90 (limited dtypes/tilings coverage)
|
||||
* Block-scaled dense GEMM:
|
||||
- Compute capabilities: 100, 103
|
||||
- Input precisions: MXFP8 (F8E4M3, F8E5M2)
|
||||
- Output precisions: F32, F16, BF16, F8E4M3, F8E5M2
|
||||
- Scale factor precisions: F8E8M0
|
||||
* Grouped GEMM
|
||||
- Contiguous offset 2d-3d grouped GEMM: `(TotalM, K) @ (G, K, N) -> (TotalM, N)`
|
||||
- Compute capability: 100
|
||||
- Input precisions: F16, BF16, F8E4M3, F8E5M2
|
||||
- Output precisions: F32, F16, BF16
|
||||
|
||||
* Planned additions
|
||||
* Block-scaled GEMMs
|
||||
* Grouped GEMMs
|
||||
* Block-scaled GEMMs (additional precisions)
|
||||
* Grouped GEMMs (additional variants and precisions)
|
||||
* Additional epilogue operations
|
||||
* Reductions
|
||||
* Row/column broadcasts
|
||||
|
||||
@@ -28,9 +28,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
import copy
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import TYPE_CHECKING, get_type_hints
|
||||
from typing import TYPE_CHECKING, Any, get_type_hints
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
@@ -45,10 +47,10 @@ import cutlass.cute as cute
|
||||
|
||||
from cutlass_api.fusion import EmptyTensor, trace, trace_in_out
|
||||
from cutlass_api.fusion.library import LayoutType
|
||||
from cutlass_api.library import ScaleMode, ScaleSwizzleMode
|
||||
from cutlass_api.typing import NumericLike, TensorLike
|
||||
from cutlass_api.utils import (
|
||||
TensorWrapper,
|
||||
add_batch_mode,
|
||||
is_torch_tensor,
|
||||
to_cutlass_type,
|
||||
)
|
||||
@@ -172,10 +174,6 @@ class EpilogueArguments:
|
||||
"""Returns the list of names of the input and output parameters of the epilogue"""
|
||||
return list(self.tensors.keys())
|
||||
|
||||
def add_batch_modes(self):
|
||||
for name in self.tensors:
|
||||
self.tensors[name] = add_batch_mode(self.tensors[name])
|
||||
|
||||
def to_tensor_wrappers(self, permute: list[int] | None = None):
|
||||
"""Converts the input and output parameters of the epilogue to TensorWrappers"""
|
||||
for k, v in self.tensors.items():
|
||||
@@ -215,6 +213,42 @@ class EpilogueArguments:
|
||||
)
|
||||
|
||||
|
||||
def convert_to_internal_types(caller, metadata: dict[str, Any] = None):
|
||||
"""
|
||||
Converts fields of the caller to internal types. Current fields that are converted:
|
||||
* TensorLike -> TensorWrapper
|
||||
* NumericLike -> cutlass.Numeric
|
||||
* Classes that implement _convert_to_internal_types -> their internal types
|
||||
|
||||
:param caller: The caller object to convert the fields of
|
||||
:type caller: Any
|
||||
:param metadata: Additional metadata to be used for conversion
|
||||
:type metadata: dict[str, Any] | None
|
||||
"""
|
||||
|
||||
type_hints = get_type_hints(type(caller))
|
||||
for f in fields(caller):
|
||||
hint = type_hints.get(f.name)
|
||||
value = getattr(caller, f.name)
|
||||
|
||||
global_metadata = {} if metadata is None else copy.deepcopy(metadata)
|
||||
global_metadata.update(f.metadata)
|
||||
|
||||
if hint is TensorLike:
|
||||
# Find all fields that are annotated as TensorLike,
|
||||
# and wrap them in TensorWrapper
|
||||
setattr(caller, f.name, TensorWrapper(value, **global_metadata))
|
||||
elif hint is NumericLike:
|
||||
# Find all fields that are annotated as NumericLike,
|
||||
# and convert them to cutlass.Numeric
|
||||
setattr(caller, f.name, to_cutlass_type(value))
|
||||
elif hasattr(value, "_convert_to_internal_types"):
|
||||
# If the field is an instance of a class that implements
|
||||
# _convert_to_internal_types, convert it to internal types
|
||||
value._convert_to_internal_types(metadata=global_metadata)
|
||||
setattr(caller, f.name, value)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeArguments:
|
||||
"""
|
||||
@@ -231,7 +265,6 @@ class RuntimeArguments:
|
||||
:type performance: Optional[PerformanceControls]
|
||||
"""
|
||||
|
||||
epilogue: EpilogueArguments | None = field(default=None, kw_only=True)
|
||||
performance: PerformanceControls | None = field(default=None, kw_only=True)
|
||||
|
||||
def _validate(self):
|
||||
@@ -241,24 +274,92 @@ class RuntimeArguments:
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
self._validate()
|
||||
self._convert_to_internal_types()
|
||||
convert_to_internal_types(self)
|
||||
|
||||
def _convert_to_internal_types(self):
|
||||
"""
|
||||
Converts all fields to their internal types.
|
||||
"""
|
||||
type_hints = get_type_hints(type(self))
|
||||
for f in fields(self):
|
||||
hint = type_hints.get(f.name)
|
||||
value = getattr(self, f.name)
|
||||
|
||||
# Find all fields that are annotated as TensorLike,
|
||||
# and wrap them in TensorWrapper
|
||||
if hint is TensorLike:
|
||||
setattr(self, f.name, TensorWrapper(value))
|
||||
elif hint is NumericLike:
|
||||
setattr(self, f.name, to_cutlass_type(value))
|
||||
class KernelOperand(ABC):
|
||||
"""
|
||||
Base class for all operands to kernels.
|
||||
"""
|
||||
|
||||
def _convert_to_internal_types(self, metadata: dict[str, Any] = None):
|
||||
convert_to_internal_types(self, metadata=metadata)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DenseTensor(KernelOperand):
|
||||
"""
|
||||
Representation of a dense tensor operand.
|
||||
"""
|
||||
|
||||
tensor: TensorLike
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
if hasattr(self.tensor, attr):
|
||||
return getattr(self.tensor, attr)
|
||||
else:
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{attr}'"
|
||||
)
|
||||
|
||||
|
||||
def kernel_operand_or_default(operand: TensorLike | KernelOperand) -> KernelOperand:
|
||||
"""
|
||||
If operand is already a KernelOperand, return it. Otherwise, wrap it in a DenseTensor.
|
||||
This is used for convenience interfaces to avoid needing to wrap dense tensors in a DenseTensor
|
||||
class.
|
||||
|
||||
:param operand: The operand to convert
|
||||
:type operand: TensorLike | KernelOperand
|
||||
:return: The operand wrapped in a DenseTensor if it is a TensorLike, otherwise the operand itself
|
||||
:rtype: KernelOperand
|
||||
"""
|
||||
if isinstance(operand, TensorLike):
|
||||
return DenseTensor(operand)
|
||||
else:
|
||||
return operand
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScaledTensor(KernelOperand):
|
||||
"""
|
||||
Representation of a scaled tensor operand. This includes:
|
||||
* A base tensor. This is a KernelOperand subclass.
|
||||
* A tensor containing scale factors
|
||||
* Scale mode and swizzle mode
|
||||
|
||||
An example of its creation is the following:
|
||||
```python
|
||||
A = torch.rand(...)
|
||||
scale_A = torch.rand(...)
|
||||
arg = ScaledTensor(A, scale_A, ScaleMode.Blockwise1x32, ScaleSwizzleMode.Swizzle32x4x4)
|
||||
```
|
||||
"""
|
||||
|
||||
base: KernelOperand
|
||||
scale: DenseTensor
|
||||
mode: ScaleMode | tuple[int, ...]
|
||||
swizzle: ScaleSwizzleMode
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base: KernelOperand | TensorLike,
|
||||
scale: DenseTensor | TensorLike,
|
||||
mode: ScaleMode | tuple[int, ...],
|
||||
swizzle: ScaleSwizzleMode,
|
||||
):
|
||||
self.base = kernel_operand_or_default(base)
|
||||
self.scale = kernel_operand_or_default(scale)
|
||||
self.mode = mode
|
||||
self.swizzle = swizzle
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
if hasattr(self.base, attr):
|
||||
return getattr(self.base, attr)
|
||||
else:
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{attr}'"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -273,24 +374,70 @@ class GemmArguments(RuntimeArguments):
|
||||
N: Number of columns in B and out
|
||||
|
||||
:param A: Input tensor A of shape (L, M, K) or (M, K)
|
||||
:type A: TensorLike
|
||||
:type A: KernelOperand
|
||||
:param B: Input tensor B of shape (L, K, N) or (K, N)
|
||||
:type B: TensorLike
|
||||
:type B: KernelOperand
|
||||
:param out: Output tensor C of shape (L, M, N) or (M, N)
|
||||
:type out: TensorLike
|
||||
:type out: KernelOperand
|
||||
:param accumulator_type: Data type of the accumulator
|
||||
:type accumulator_type: NumericLike
|
||||
"""
|
||||
|
||||
A: TensorLike
|
||||
B: TensorLike
|
||||
out: TensorLike
|
||||
A: KernelOperand
|
||||
B: KernelOperand
|
||||
out: KernelOperand
|
||||
accumulator_type: NumericLike
|
||||
epilogue: EpilogueArguments | None
|
||||
|
||||
def to_operands_metadata(self) -> GemmOperandsMetadata:
|
||||
from cutlass_api.metadata import GemmOperandsMetadata
|
||||
def __init__(
|
||||
self,
|
||||
A: TensorLike | KernelOperand,
|
||||
B: TensorLike | KernelOperand,
|
||||
out: TensorLike | KernelOperand,
|
||||
accumulator_type: NumericLike,
|
||||
epilogue: EpilogueArguments | None = None,
|
||||
):
|
||||
"""
|
||||
Construct a GemmArguments object.
|
||||
|
||||
return GemmOperandsMetadata.from_args(self)
|
||||
For convenience, construction for `A`, `B`, and `out` that are operands
|
||||
to a dense GEMM can be passed in without wrapping them in `DenseTensor`.
|
||||
|
||||
```python
|
||||
GemmArguments(A, B, out, accumulator_type)
|
||||
# is equivalent to:
|
||||
GemmArguments(DenseTensor(A), DenseTensor(B), DenseTensor(out), accumulator_type)
|
||||
```
|
||||
|
||||
Other operand types must explicitly wrap tensors in a `KernelOperand` subclass.
|
||||
For example, a scaled GEMM can be constructed as:
|
||||
```python
|
||||
GemmArguments(
|
||||
ScaledTensor(A, ScaleATensor, scale_mode, scale_swizzle),
|
||||
ScaledTensor(B, ScaleBTensor, scale_mode, scale_swizzle),
|
||||
out, # No need to wrap out in a `DenseTensor`
|
||||
accumulator_type,
|
||||
)
|
||||
```
|
||||
|
||||
:param A: Input tensor A of shape (L, M, K) or (M, K)
|
||||
:type A: TensorLike | KernelOperand
|
||||
:param B: Input tensor B of shape (L, K, N) or (K, N)
|
||||
:type B: TensorLike | KernelOperand
|
||||
:param out: Output tensor C of shape (L, M, N) or (M, N)
|
||||
:type out: TensorLike | KernelOperand
|
||||
:param accumulator_type: Data type of the accumulator
|
||||
:type accumulator_type: NumericLike
|
||||
:param epilogue: Optional custom epilogue fusion to be performed after the GEMM.
|
||||
:type epilogue: Optional[EpilogueArguments]
|
||||
"""
|
||||
self.A = kernel_operand_or_default(A)
|
||||
self.B = kernel_operand_or_default(B)
|
||||
self.out = kernel_operand_or_default(out)
|
||||
|
||||
self.accumulator_type = accumulator_type
|
||||
self.epilogue = epilogue
|
||||
super().__init__()
|
||||
|
||||
def _validate(self):
|
||||
"""
|
||||
@@ -329,13 +476,79 @@ class GemmArguments(RuntimeArguments):
|
||||
f"out & A must have the same rank and batch dimension (if any). out shape (L, M, N): {self.out.shape}, A shape (L, M, K): {self.A.shape}"
|
||||
)
|
||||
|
||||
if isinstance(self.epilogue, EpilogueArguments):
|
||||
def _convert_epilogue(self):
|
||||
"""Converts the epilogue to an internal representation using internal types."""
|
||||
if self.epilogue is not None:
|
||||
L = self.A.shape[0] if len(self.A.shape) == 3 else 1
|
||||
M, N = self.A.shape[-2], self.B.shape[-1]
|
||||
accum_shape = (L, M, N)
|
||||
self.epilogue.trace(accum_shape, self.accumulator_type)
|
||||
self.epilogue.to_tensor_wrappers()
|
||||
|
||||
def __post_init__(self):
|
||||
self._validate()
|
||||
self._convert_epilogue()
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroupedGemmArguments(RuntimeArguments):
|
||||
"""
|
||||
Arguments for a grouped GEMM operation. A grouped GEMM performs a series
|
||||
of independent GEMM operations.
|
||||
|
||||
The most basic formulation of a grouped GEMM is one in which one has lists
|
||||
of A tensors and lists of B tensors, and computes the following:
|
||||
```python
|
||||
for i in range(problems_in_group):
|
||||
out[i] = A[i] @ B[i]
|
||||
```
|
||||
|
||||
Note that each constituent GEMM operation can have different sizes.
|
||||
|
||||
Though the abstract formulation of a grouped GEMM treats operands as parts
|
||||
of a list, in practice, the operands are often packed together contiguously in memory.
|
||||
In this case, an `offsets` tensor delineates the ending positions of each problem in the group.
|
||||
|
||||
Currently-supported variants of a grouped GEMM are:
|
||||
* Contiguous offset 2D/3D:
|
||||
* A is a tensor of shape (TotalM, K) or (1, TotalM, K)
|
||||
* B is a tensor of shape (problems_in_group, K, N)
|
||||
* out is a tensor of shape (TotalM, N) or (1, TotalM, N)
|
||||
* offsets is a tensor delineating the ending positions of each problem in the group.
|
||||
```python
|
||||
start = 0
|
||||
for i in range(problems_in_group):
|
||||
end = offsets[i]
|
||||
out[start:end, :] = A[start:end, :] @ B[i, :, :]
|
||||
start = end
|
||||
```
|
||||
"""
|
||||
|
||||
A: KernelOperand
|
||||
B: KernelOperand
|
||||
out: KernelOperand
|
||||
accumulator_type: NumericLike
|
||||
offsets: KernelOperand = field(metadata={"alignment_bytes": 4})
|
||||
epilogue: EpilogueArguments | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
A: TensorLike | KernelOperand,
|
||||
B: TensorLike | KernelOperand,
|
||||
out: TensorLike | KernelOperand,
|
||||
accumulator_type: NumericLike,
|
||||
offsets: TensorLike | KernelOperand,
|
||||
epilogue: EpilogueArguments | None = None,
|
||||
):
|
||||
self.A = kernel_operand_or_default(A)
|
||||
self.B = kernel_operand_or_default(B)
|
||||
self.out = kernel_operand_or_default(out)
|
||||
self.accumulator_type = accumulator_type
|
||||
self.offsets = kernel_operand_or_default(offsets)
|
||||
self.epilogue = epilogue
|
||||
super().__init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ElementwiseArguments(RuntimeArguments):
|
||||
@@ -343,21 +556,27 @@ class ElementwiseArguments(RuntimeArguments):
|
||||
Arguments needed for an elementwise operation.
|
||||
|
||||
:param A: The input tensor A.
|
||||
:type A: TensorLike
|
||||
:type A: TensorLike | KernelOperand
|
||||
:param B: The input tensor B.
|
||||
:type B: TensorLike
|
||||
:type B: TensorLike | KernelOperand
|
||||
:param out: The output tensor.
|
||||
:type out: TensorLike
|
||||
:type out: TensorLike | KernelOperand
|
||||
"""
|
||||
|
||||
A: TensorLike
|
||||
B: TensorLike
|
||||
out: TensorLike
|
||||
A: KernelOperand
|
||||
B: KernelOperand
|
||||
out: KernelOperand
|
||||
|
||||
def to_operands_metadata(self) -> ElementwiseOperandsMetadata:
|
||||
from cutlass_api.metadata import ElementwiseOperandsMetadata
|
||||
|
||||
return ElementwiseOperandsMetadata.from_args(self)
|
||||
def __init__(
|
||||
self,
|
||||
A: TensorLike | KernelOperand,
|
||||
B: TensorLike | KernelOperand,
|
||||
out: TensorLike | KernelOperand,
|
||||
):
|
||||
self.A = kernel_operand_or_default(A)
|
||||
self.B = kernel_operand_or_default(B)
|
||||
self.out = kernel_operand_or_default(out)
|
||||
super().__init__()
|
||||
|
||||
def _validate(self):
|
||||
"""
|
||||
@@ -371,3 +590,7 @@ class ElementwiseArguments(RuntimeArguments):
|
||||
raise ValueError(
|
||||
f"out.shape ({self.out.shape}) must be equal to A.shape ({self.A.shape})"
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self._validate()
|
||||
super().__post_init__()
|
||||
|
||||
82
python/cutlass_api/cutlass_api/library.py
Normal file
82
python/cutlass_api/cutlass_api/library.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import enum
|
||||
from enum import auto as enum_auto
|
||||
from typing import Self
|
||||
|
||||
|
||||
class ScaleMode(enum.Enum):
|
||||
"""
|
||||
Type of scaling used for scaled operations. Each scaling enum corresponds to
|
||||
a tuple indicating the units the scale covers in the (L, M, K) modes.
|
||||
|
||||
For example, Blockwise1x16 corresponds to (1, 1, 16), meaning the scale operates
|
||||
over 16 elements in the K mode.
|
||||
"""
|
||||
|
||||
Blockwise1x16 = (1, 1, 16)
|
||||
Blockwise1x32 = (1, 1, 32)
|
||||
|
||||
@staticmethod
|
||||
def compare(mode1: Self | tuple[int, ...], mode2: Self | tuple[int, ...]) -> bool:
|
||||
if isinstance(mode1, ScaleMode) and isinstance(mode2, ScaleMode):
|
||||
return mode1.value == mode2.value
|
||||
|
||||
# One of the two modes is a tuple. Use tuple comparison and allow
|
||||
# for different lengths as long as long as the longer tuple contains
|
||||
# only 1s for the extra leading positions (i.e., allow (1,1,16) == (1, 16))
|
||||
tuple1 = mode1 if isinstance(mode1, tuple) else mode1.value
|
||||
tuple2 = mode2 if isinstance(mode2, tuple) else mode2.value
|
||||
|
||||
if len(tuple1) == len(tuple2):
|
||||
return tuple1 == tuple2
|
||||
elif len(tuple1) < len(tuple2):
|
||||
padding = (1,) * (len(tuple2) - len(tuple1))
|
||||
return padding + tuple1 == tuple2
|
||||
else:
|
||||
padding = (1,) * (len(tuple1) - len(tuple2))
|
||||
return padding + tuple2 == tuple1
|
||||
|
||||
def __eq__(self, other: Self | tuple[int, ...]) -> bool:
|
||||
return ScaleMode.compare(self, other)
|
||||
|
||||
def __ne__(self, other: Self | tuple[int, ...]) -> bool:
|
||||
return not ScaleMode.compare(self, other)
|
||||
|
||||
def __getitem__(self, index: int) -> int:
|
||||
return self.value[index]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class ScaleSwizzleMode(enum.Enum):
|
||||
"""Swizzle mode used for scale factors"""
|
||||
SwizzleNone = enum_auto()
|
||||
Swizzle32x4x4 = enum_auto()
|
||||
@@ -90,7 +90,7 @@ class Manifest:
|
||||
return False
|
||||
return True
|
||||
|
||||
epilogue_args = None if args is None else args.epilogue
|
||||
epilogue_args = getattr(args, "epilogue", None)
|
||||
kernels = [
|
||||
k
|
||||
# Generate kernels from all providers
|
||||
|
||||
@@ -38,13 +38,17 @@ if TYPE_CHECKING:
|
||||
import cutlass.cute as cute
|
||||
|
||||
from cutlass_api.arguments import (
|
||||
DenseTensor,
|
||||
ElementwiseArguments,
|
||||
EpilogueArguments,
|
||||
GemmArguments,
|
||||
GroupedGemmArguments,
|
||||
KernelOperand,
|
||||
RuntimeArguments,
|
||||
ScaledTensor,
|
||||
)
|
||||
from cutlass_api.library import ScaleMode, ScaleSwizzleMode
|
||||
from cutlass_api.status import Status
|
||||
from cutlass_api.utils import TensorWrapper
|
||||
|
||||
|
||||
def _convert_stride(shape: tuple[int, ...], stride: tuple[int, ...]) -> tuple[int, ...]:
|
||||
@@ -65,6 +69,11 @@ def _convert_stride(shape: tuple[int, ...], stride: tuple[int, ...]) -> tuple[in
|
||||
new_stride.append(0)
|
||||
else:
|
||||
new_stride.append(stride[i])
|
||||
|
||||
# If the size of the shape is 1, assign the last mode a stride of 1.
|
||||
if all(x == 0 for x in new_stride):
|
||||
new_stride[-1] = 1
|
||||
|
||||
return new_stride
|
||||
|
||||
|
||||
@@ -110,7 +119,7 @@ class TensorAttributes:
|
||||
:param dtype: The data type of the tensor.
|
||||
:type dtype: cutlass.Numeric
|
||||
:param stride: The stride of the tensor.
|
||||
:type stride: tuple[int, ...]
|
||||
:type stride: tuple[int, ...] | None
|
||||
:param divisibility: The divisibility requirement on a tensor's stride & shape elements
|
||||
:type divisibility: int
|
||||
:param ptr_alignment_bytes: The alignment of the tensor's data pointer, in bytes.
|
||||
@@ -119,7 +128,7 @@ class TensorAttributes:
|
||||
"""
|
||||
|
||||
dtype: cutlass.Numeric # F32, F16, etc.
|
||||
stride: tuple[int, ...]
|
||||
stride: tuple[int, ...] | None
|
||||
divisibility: int
|
||||
ptr_alignment_bytes: int
|
||||
|
||||
@@ -137,19 +146,26 @@ class TensorAttributes:
|
||||
(divisibility * dtype.width) // 8
|
||||
)
|
||||
|
||||
def supports(self, operand: TensorWrapper | Self) -> Status:
|
||||
def supports(self, operand: KernelOperand | Self) -> Status:
|
||||
"""
|
||||
Checks whether the provided args satisfy the properties described by
|
||||
these TensorAttributes.
|
||||
|
||||
:param operand: The operand to check support for.
|
||||
:type operand: TensorWrapper | Self
|
||||
:type operand: KernelOperand | Self
|
||||
|
||||
:return: Whether the provided operand satisfies the properties described by
|
||||
these TensorAttributes.
|
||||
:rtype: Status
|
||||
"""
|
||||
if isinstance(operand, TensorWrapper):
|
||||
if not isinstance(operand, KernelOperand) and not isinstance(
|
||||
operand, TensorAttributes
|
||||
):
|
||||
return Status.fail(
|
||||
f"Expected KernelOperand or TensorAttributes, got {type(operand)}"
|
||||
)
|
||||
|
||||
if isinstance(operand, KernelOperand):
|
||||
if operand.element_type != self.dtype:
|
||||
return Status.fail(
|
||||
f"Expected element type {self.dtype}, got {operand.element_type}"
|
||||
@@ -162,37 +178,37 @@ class TensorAttributes:
|
||||
# Normalize the operand stride to have 0's where operand.shape is 1
|
||||
normalized_operand_stride = _convert_stride(operand.shape, operand.stride)
|
||||
|
||||
# If the operand is batched, the expected stride is the same as the self.stride
|
||||
if len(self.stride) == len(normalized_operand_stride):
|
||||
expected_stride = self.stride
|
||||
# We can try if strides match without the batch mode
|
||||
elif len(self.stride) - 1 == len(normalized_operand_stride):
|
||||
expected_stride = self.stride[1:]
|
||||
if self.stride is None:
|
||||
contiguous_dim = normalized_operand_stride.index(1)
|
||||
else:
|
||||
return Status.fail(
|
||||
f"Expected tensor with strides of rank {len(self.stride)} (batched) or {len(self.stride) - 1} (unbatched), got {len(normalized_operand_stride)} ({normalized_operand_stride})"
|
||||
)
|
||||
# If the operand is batched, the expected stride is the same as the self.stride
|
||||
if len(self.stride) == len(normalized_operand_stride):
|
||||
expected_stride = self.stride
|
||||
# We can try if strides match without the batch mode
|
||||
elif len(self.stride) - 1 == len(normalized_operand_stride):
|
||||
expected_stride = self.stride[1:]
|
||||
else:
|
||||
return Status.fail(
|
||||
f"Expected tensor with strides of rank {len(self.stride)} (batched) or {len(self.stride) - 1} (unbatched), got {len(normalized_operand_stride)} ({normalized_operand_stride})"
|
||||
)
|
||||
|
||||
# Strides are considered compatible if:
|
||||
# 1. They have the same rank (checked above)
|
||||
# 2. The leading mode is the same (or both are all 0)
|
||||
all_zeros = all(x == 0 for x in expected_stride) and all(
|
||||
x == 0 for x in normalized_operand_stride
|
||||
)
|
||||
# Strides are considered compatible if:
|
||||
# 1. They have the same rank (checked above)
|
||||
# 2. The leading mode is the same
|
||||
|
||||
# When setting stride from args, any modes of stride 1 and shape 1
|
||||
# are changed to have stride 0. Thus, there will only be one mode
|
||||
# with stride 1.
|
||||
contiguous_dim = expected_stride.index(1)
|
||||
if not all_zeros and normalized_operand_stride[contiguous_dim] != 1:
|
||||
return Status.fail(
|
||||
f"Expected stride[{contiguous_dim}] to be 1, got "
|
||||
f"{normalized_operand_stride[contiguous_dim]} "
|
||||
f"(strides: {normalized_operand_stride})"
|
||||
)
|
||||
# When setting stride from args, any modes of stride 1 and shape 1
|
||||
# are changed to have stride 0. Thus, there will only be one mode
|
||||
# with stride 1.
|
||||
contiguous_dim = expected_stride.index(1)
|
||||
if normalized_operand_stride[contiguous_dim] != 1:
|
||||
return Status.fail(
|
||||
f"Expected stride[{contiguous_dim}] to be 1, got "
|
||||
f"{normalized_operand_stride[contiguous_dim]} "
|
||||
f"(strides: {normalized_operand_stride})"
|
||||
)
|
||||
|
||||
# Check that divisibility constraints are met
|
||||
if isinstance(operand, TensorWrapper):
|
||||
if isinstance(operand, KernelOperand):
|
||||
if not _is_tuple_aligned(
|
||||
normalized_operand_stride, self.divisibility, contiguous_dim
|
||||
):
|
||||
@@ -209,7 +225,7 @@ class TensorAttributes:
|
||||
|
||||
# Check data ptr alignment, if available
|
||||
if (
|
||||
isinstance(operand, TensorWrapper)
|
||||
isinstance(operand, KernelOperand)
|
||||
and operand.data_ptr % self.ptr_alignment_bytes != 0
|
||||
):
|
||||
return Status.fail(
|
||||
@@ -218,23 +234,104 @@ class TensorAttributes:
|
||||
|
||||
return Status.success()
|
||||
|
||||
@staticmethod
|
||||
def from_tensor(tensor) -> Self:
|
||||
"""
|
||||
Create a TensorAttributes from a tensor.
|
||||
|
||||
:param tensor: The tensor to create a TensorAttributes from.
|
||||
:type tensor: cute.Tensor
|
||||
@dataclass
|
||||
class DenseTensorAttributes:
|
||||
"""
|
||||
Description of a dense tensor. This includes the data type, stride, and alignment.
|
||||
"""
|
||||
|
||||
:return: The TensorAttributes corresponding to the provided tensor.
|
||||
:rtype: TensorAttributes
|
||||
"""
|
||||
stride = _convert_stride(tensor.shape, tensor.stride)
|
||||
max_divisibility = _get_max_pow2_divisibility(tensor.shape, stride)
|
||||
return TensorAttributes(
|
||||
dtype=tensor.element_type, stride=stride, divisibility=max_divisibility
|
||||
_tensor: TensorAttributes
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dtype: cutlass.Numeric,
|
||||
stride: tuple[int, ...],
|
||||
divisibility: int,
|
||||
ptr_alignment_bytes: int | None = None,
|
||||
):
|
||||
self._tensor = TensorAttributes(
|
||||
dtype, stride, divisibility, ptr_alignment_bytes
|
||||
)
|
||||
|
||||
def supports(self, operand: DenseTensor | Self) -> Status:
|
||||
"""
|
||||
Checks whether the provided args satisfy the properties described by
|
||||
these TensorAttributes.
|
||||
|
||||
:param operand: The operand to check support for.
|
||||
:type operand: DenseTensor | Self
|
||||
|
||||
:return: Whether the provided operand satisfies the properties described by
|
||||
these TensorAttributes.
|
||||
:rtype: Status
|
||||
"""
|
||||
if not isinstance(operand, DenseTensor) and not isinstance(
|
||||
operand, DenseTensorAttributes
|
||||
):
|
||||
return Status.fail(
|
||||
f"Expected DenseTensor or DenseTensorAttributes, got {type(operand)}"
|
||||
)
|
||||
|
||||
return self._tensor.supports(operand)
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
if hasattr(self._tensor, attr):
|
||||
return getattr(self._tensor, attr)
|
||||
else:
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{attr}'"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScaledTensorAttributes:
|
||||
"""
|
||||
Description of a scaled tensor. This includes the base tensor and the
|
||||
scale tensor + scale mode + scale swizzle combination.
|
||||
"""
|
||||
|
||||
base: TensorAttributes
|
||||
scale: DenseTensorAttributes
|
||||
mode: ScaleMode | tuple[int, ...]
|
||||
swizzle: ScaleSwizzleMode
|
||||
|
||||
def supports(self, operand: ScaledTensor | Self) -> Status:
|
||||
"""
|
||||
Checks whether the provided args satisfy the properties described by
|
||||
these ScaledTensorAttributes.
|
||||
"""
|
||||
if not isinstance(operand, ScaledTensor) and not isinstance(
|
||||
operand, ScaledTensorAttributes
|
||||
):
|
||||
return Status.fail(
|
||||
f"Expected ScaledTensor or ScaledTensorAttributes, got {type(operand)}"
|
||||
)
|
||||
|
||||
if not (status := self.base.supports(operand.base)):
|
||||
return status
|
||||
|
||||
if not (status := self.scale.supports(operand.scale)):
|
||||
return status
|
||||
|
||||
if not ScaleMode.compare(self.mode, operand.mode):
|
||||
return Status.fail(f"Expected scale mode {self.mode}, got {operand.mode}")
|
||||
|
||||
if self.swizzle != operand.swizzle:
|
||||
return Status.fail(
|
||||
f"Expected scale swizzle mode {self.swizzle}, got {operand.swizzle}"
|
||||
)
|
||||
|
||||
return Status.success()
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
if hasattr(self.base, attr):
|
||||
return getattr(self.base, attr)
|
||||
else:
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{attr}'"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OperandsMetadata(ABC):
|
||||
@@ -326,11 +423,11 @@ class GemmOperandsMetadata(OperandsMetadata):
|
||||
"""
|
||||
Metadata for the operands of a GEMM operation.
|
||||
|
||||
:param A: Metadata for the input tensor A.
|
||||
:param A: Metadata for operand `A` of the GEMM.
|
||||
:type A: TensorAttributes
|
||||
:param B: Metadata for the input tensor B.
|
||||
:param B: Metadata for operand `B` of the GEMM.
|
||||
:type B: TensorAttributes
|
||||
:param out: Metadata for the output tensor.
|
||||
:param out: Metadata for operand `out` of the GEMM.
|
||||
:type out: TensorAttributes
|
||||
:param accumulator_type: The data type of the accumulator tensor.
|
||||
:type accumulator_type: cutlass.Numeric
|
||||
@@ -372,23 +469,65 @@ class GemmOperandsMetadata(OperandsMetadata):
|
||||
|
||||
return Status.success()
|
||||
|
||||
@staticmethod
|
||||
def from_args(args: GemmArguments) -> Self:
|
||||
"""
|
||||
Create a GemmOperandsMetadata from a GemmArguments.
|
||||
|
||||
:param args: The GemmArguments to create a GemmOperandsMetadata from.
|
||||
:type args: GemmArguments
|
||||
@dataclass
|
||||
class GroupedGemmOperandsMetadata(OperandsMetadata):
|
||||
"""
|
||||
Metadata for the operands of a grouped GEMM operation.
|
||||
|
||||
:return: The GemmOperandsMetadata corresponding to the provided GemmArguments.
|
||||
:rtype: GemmOperandsMetadata
|
||||
:param A: Metadata for operand `A` of the grouped GEMM.
|
||||
:type A: TensorAttributes
|
||||
:param B: Metadata for operand `B` of the grouped GEMM.
|
||||
:type B: TensorAttributes
|
||||
:param out: Metadata for operand `out` of the grouped GEMM.
|
||||
:type out: TensorAttributes
|
||||
:param accumulator_type: The data type of the accumulator tensor.
|
||||
:type accumulator_type: cutlass.Numeric
|
||||
:param offsets: Metadata for operand `offsets` of the grouped GEMM.
|
||||
:type offsets: TensorAttributes
|
||||
"""
|
||||
|
||||
A: TensorAttributes
|
||||
B: TensorAttributes
|
||||
out: TensorAttributes
|
||||
offsets: TensorAttributes
|
||||
accumulator_type: cutlass.Numeric
|
||||
|
||||
def supports(self, other: GroupedGemmArguments | Self) -> Status:
|
||||
"""
|
||||
return GemmOperandsMetadata(
|
||||
A=TensorAttributes.from_tensor(args.A),
|
||||
B=TensorAttributes.from_tensor(args.B),
|
||||
out=TensorAttributes.from_tensor(args.out),
|
||||
accumulator_type=args.accumulator_type,
|
||||
)
|
||||
Checks whether the provided args satisfy the properties described by
|
||||
the operands in this metadata.
|
||||
|
||||
:param other: The arguments to check support for.
|
||||
:type other: GroupedGemmArguments | Self
|
||||
|
||||
:return: Whether the provided args satisfy the properties described by
|
||||
the operands in this metadata.
|
||||
:rtype: Status
|
||||
"""
|
||||
if isinstance(other, RuntimeArguments) and not isinstance(
|
||||
other, GroupedGemmArguments
|
||||
):
|
||||
return Status.fail(f"Expected GroupedGemmArguments, got {type(other)}")
|
||||
|
||||
if not (status := self.A.supports(other.A)):
|
||||
return Status.fail(f"Operand `A` is unsupported: {status.error}")
|
||||
|
||||
if not (status := self.B.supports(other.B)):
|
||||
return Status.fail(f"Operand `B` is unsupported: {status.error}")
|
||||
|
||||
if not (status := self.out.supports(other.out)):
|
||||
return Status.fail(f"Operand `out` is unsupported: {status.error}")
|
||||
|
||||
if not (status := self.offsets.supports(other.offsets)):
|
||||
return Status.fail(f"Operand `offsets` is unsupported: {status.error}")
|
||||
|
||||
if self.accumulator_type != other.accumulator_type:
|
||||
return Status.fail(
|
||||
f"Expected accumulator type {self.accumulator_type}, got {other.accumulator_type}"
|
||||
)
|
||||
|
||||
return Status.success()
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -436,23 +575,6 @@ class ElementwiseOperandsMetadata(OperandsMetadata):
|
||||
|
||||
return Status.success()
|
||||
|
||||
@staticmethod
|
||||
def from_args(args: ElementwiseArguments) -> Self:
|
||||
"""
|
||||
Create a ElementwiseOperandsMetadata from a ElementwiseArguments.
|
||||
|
||||
:param args: The ElementwiseArguments to create a ElementwiseOperandsMetadata from.
|
||||
:type args: ElementwiseArguments
|
||||
|
||||
:return: The ElementwiseOperandsMetadata corresponding to the provided ElementwiseArguments.
|
||||
:rtype: ElementwiseOperandsMetadata
|
||||
"""
|
||||
return ElementwiseOperandsMetadata(
|
||||
A=TensorAttributes.from_tensor(args.A),
|
||||
B=TensorAttributes.from_tensor(args.B),
|
||||
out=TensorAttributes.from_tensor(args.out),
|
||||
)
|
||||
|
||||
|
||||
class EpilogueMetadata:
|
||||
def __init__(self, epilogue_args: EpilogueArguments):
|
||||
@@ -535,7 +657,11 @@ class KernelMetadata:
|
||||
if not (status := supports_or_none(self.design, args.performance, "design")):
|
||||
return status
|
||||
|
||||
if not (status := supports_or_none(self.epilogue, args.epilogue, "epilogue")):
|
||||
if not (
|
||||
status := supports_or_none(
|
||||
self.epilogue, getattr(args, "epilogue", None), "epilogue"
|
||||
)
|
||||
):
|
||||
return status
|
||||
|
||||
return Status.success()
|
||||
|
||||
@@ -38,9 +38,9 @@ import cutlass.cute as cute
|
||||
from cutlass_api.arguments import ElementwiseArguments, EpilogueArguments
|
||||
from cutlass_api.artifact import CompiledArtifact
|
||||
from cutlass_api.metadata import (
|
||||
DenseTensorAttributes,
|
||||
ElementwiseOperandsMetadata,
|
||||
KernelMetadata,
|
||||
TensorAttributes,
|
||||
)
|
||||
from cutlass_api.providers.cutedsl import CuTeDSLProvider
|
||||
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
|
||||
@@ -65,7 +65,9 @@ class ElementwiseAddKernel(CuteDslKernel):
|
||||
|
||||
def compile(self, args: ElementwiseArguments, cc: int = None) -> CompiledArtifact:
|
||||
stream = cutlass.cute.runtime.make_fake_stream()
|
||||
compiled_kernel = self.cute_compile(self.impl, args.A, args.B, args.out, stream)
|
||||
compiled_kernel = self.cute_compile(
|
||||
self.impl, args.A.tensor, args.B.tensor, args.out.tensor, stream
|
||||
)
|
||||
return CompiledArtifact(compiled_kernel, self)
|
||||
|
||||
def _run(
|
||||
@@ -77,7 +79,9 @@ class ElementwiseAddKernel(CuteDslKernel):
|
||||
) -> None:
|
||||
stream = to_cuda_stream(stream)
|
||||
compiled_kernel = compiled_artifact.compiled_obj
|
||||
self.cute_run(compiled_kernel, args.A, args.B, args.out, stream)
|
||||
self.cute_run(
|
||||
compiled_kernel, args.A.tensor, args.B.tensor, args.out.tensor, stream
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def generate_kernels(
|
||||
@@ -111,13 +115,13 @@ class ElementwiseAddKernel(CuteDslKernel):
|
||||
)
|
||||
|
||||
operands = ElementwiseOperandsMetadata(
|
||||
A=TensorAttributes(
|
||||
A=DenseTensorAttributes(
|
||||
dtype=dtype, stride=stride_A, divisibility=divisibility
|
||||
),
|
||||
B=TensorAttributes(
|
||||
B=DenseTensorAttributes(
|
||||
dtype=dtype, stride=stride_B, divisibility=divisibility
|
||||
),
|
||||
out=TensorAttributes(
|
||||
out=DenseTensorAttributes(
|
||||
dtype=dtype, stride=stride_out, divisibility=divisibility
|
||||
),
|
||||
)
|
||||
|
||||
@@ -31,4 +31,6 @@
|
||||
# ruff: noqa: F401
|
||||
import cutlass_api.providers.cutedsl.gemm.sm100_static_persistent
|
||||
import cutlass_api.providers.cutedsl.gemm.sm100_static_persistent_efc
|
||||
import cutlass_api.providers.cutedsl.gemm.sm100_dense_blockscaled_static_persistent
|
||||
import cutlass_api.providers.cutedsl.gemm.sm80_tensorop_gemm
|
||||
import cutlass_api.providers.cutedsl.gemm.sm100_contiguous_offset_2d3d_dense_gemm
|
||||
|
||||
@@ -0,0 +1,602 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from cutlass.cutlass_dsl import (
|
||||
Boolean,
|
||||
Integer,
|
||||
Int32,
|
||||
min,
|
||||
extract_mlir_values,
|
||||
new_from_mlir_values,
|
||||
dsl_user_op,
|
||||
const_expr,
|
||||
)
|
||||
from cutlass._mlir import ir
|
||||
import cutlass.cute as cute
|
||||
|
||||
|
||||
class WorkTileInfo:
|
||||
"""A class to represent information about a work tile.
|
||||
|
||||
:ivar tile_idx: The index of the tile.
|
||||
:type tile_idx: cute.Coord
|
||||
:ivar is_valid_tile: Whether the tile is valid.
|
||||
:type is_valid_tile: Boolean
|
||||
"""
|
||||
|
||||
def __init__(self, tile_idx: cute.Coord, is_valid_tile: Boolean):
|
||||
self._tile_idx = tile_idx
|
||||
self._is_valid_tile = Boolean(is_valid_tile)
|
||||
|
||||
def __extract_mlir_values__(self) -> list[ir.Value]:
|
||||
values = extract_mlir_values(self.tile_idx)
|
||||
values.extend(extract_mlir_values(self.is_valid_tile))
|
||||
return values
|
||||
|
||||
def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo":
|
||||
assert len(values) == 4
|
||||
new_tile_idx = new_from_mlir_values(self._tile_idx, values[:-1])
|
||||
new_is_valid_tile = new_from_mlir_values(self._is_valid_tile, [values[-1]])
|
||||
return WorkTileInfo(new_tile_idx, new_is_valid_tile)
|
||||
|
||||
@property
|
||||
def is_valid_tile(self) -> Boolean:
|
||||
"""Check latest tile returned by the scheduler is valid or not. Any scheduling
|
||||
requests after all tasks completed will return an invalid tile.
|
||||
|
||||
:return: The validity of the tile.
|
||||
:rtype: Boolean
|
||||
"""
|
||||
return self._is_valid_tile
|
||||
|
||||
@property
|
||||
def tile_idx(self) -> cute.Coord:
|
||||
"""
|
||||
Get the index of the tile.
|
||||
|
||||
:return: The index of the tile.
|
||||
:rtype: cute.Coord
|
||||
"""
|
||||
return self._tile_idx
|
||||
|
||||
|
||||
class PersistentTileSchedulerParams:
|
||||
"""A class to represent parameters for a persistent tile scheduler.
|
||||
|
||||
This class is designed to manage and compute the layout of clusters and tiles
|
||||
in a batched gemm problem.
|
||||
|
||||
:ivar cluster_shape_mn: Shape of the cluster in (m, n) dimensions (K dimension cta count must be 1).
|
||||
:type cluster_shape_mn: tuple
|
||||
:ivar problem_layout_ncluster_mnl: Layout of the problem in terms of
|
||||
number of clusters in (m, n, l) dimensions.
|
||||
:type problem_layout_ncluster_mnl: cute.Layout
|
||||
"""
|
||||
|
||||
@dsl_user_op
|
||||
def __init__(
|
||||
self,
|
||||
problem_shape_ntile_mnl: cute.Shape,
|
||||
cluster_shape_mnk: cute.Shape,
|
||||
swizzle_size: int = 1,
|
||||
raster_along_m: bool = True,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""
|
||||
Initializes the PersistentTileSchedulerParams with the given parameters.
|
||||
|
||||
:param problem_shape_ntile_mnl: The shape of the problem in terms of
|
||||
number of CTA (Cooperative Thread Array) in (m, n, l) dimensions.
|
||||
:type problem_shape_ntile_mnl: cute.Shape
|
||||
:param cluster_shape_mnk: The shape of the cluster in (m, n) dimensions.
|
||||
:type cluster_shape_mnk: cute.Shape
|
||||
:param swizzle_size: Swizzling size in the unit of cluster. 1 means no swizzle
|
||||
:type swizzle_size: int
|
||||
:param raster_along_m: Rasterization order of clusters. Only used when swizzle_size > 1.
|
||||
True means along M, false means along N.
|
||||
:type raster_along_m: bool
|
||||
|
||||
:raises ValueError: If cluster_shape_k is not 1.
|
||||
"""
|
||||
|
||||
if cluster_shape_mnk[2] != 1:
|
||||
raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}")
|
||||
if swizzle_size < 1:
|
||||
raise ValueError(f"expect swizzle_size >= 1, but get {swizzle_size}")
|
||||
|
||||
self.problem_shape_ntile_mnl = problem_shape_ntile_mnl
|
||||
# cluster_shape_mnk is kept for reconstruction
|
||||
self._cluster_shape_mnk = cluster_shape_mnk
|
||||
self.cluster_shape_mn = cluster_shape_mnk[:2]
|
||||
self.swizzle_size = swizzle_size
|
||||
self.raster_along_m = raster_along_m
|
||||
self._loc = loc
|
||||
|
||||
# By default, we follow m major (col-major) raster order, so make a col-major layout
|
||||
self.problem_layout_ncluster_mnl = cute.make_layout(
|
||||
cute.ceil_div(
|
||||
self.problem_shape_ntile_mnl, cluster_shape_mnk[:2], loc=loc, ip=ip
|
||||
),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
# Apply swizzle if swizzle_size > 1
|
||||
if swizzle_size > 1:
|
||||
problem_shape_ncluster_mnl = cute.round_up(
|
||||
self.problem_layout_ncluster_mnl.shape,
|
||||
(1, swizzle_size, 1) if raster_along_m else (swizzle_size, 1, 1),
|
||||
)
|
||||
|
||||
if raster_along_m:
|
||||
self.problem_layout_ncluster_mnl = cute.make_layout(
|
||||
(
|
||||
problem_shape_ncluster_mnl[0],
|
||||
(swizzle_size, problem_shape_ncluster_mnl[1] // swizzle_size),
|
||||
problem_shape_ncluster_mnl[2],
|
||||
),
|
||||
stride=(
|
||||
swizzle_size,
|
||||
(1, swizzle_size * problem_shape_ncluster_mnl[0]),
|
||||
problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
|
||||
),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
else:
|
||||
self.problem_layout_ncluster_mnl = cute.make_layout(
|
||||
(
|
||||
(swizzle_size, problem_shape_ncluster_mnl[0] // swizzle_size),
|
||||
problem_shape_ncluster_mnl[1],
|
||||
problem_shape_ncluster_mnl[2],
|
||||
),
|
||||
stride=(
|
||||
(1, swizzle_size * problem_shape_ncluster_mnl[1]),
|
||||
swizzle_size,
|
||||
problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
|
||||
),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
# Create FastDivmod divisors (only when swizzle_size == 1 for correctness)
|
||||
# FastDivmod assumes simple col-major layout, incompatible with swizzled layouts
|
||||
if swizzle_size == 1:
|
||||
problem_layout_size = cute.size(
|
||||
self.problem_layout_ncluster_mnl, loc=loc, ip=ip
|
||||
)
|
||||
cluster_count_m = self.problem_layout_ncluster_mnl.shape[0]
|
||||
cluster_count_n = self.problem_layout_ncluster_mnl.shape[1]
|
||||
|
||||
# batch_fdd: Used to map linear_idx to work_unit_id (handles persistent scheduling)
|
||||
self.batch_fdd = cute.fast_divmod_create_divisor(
|
||||
problem_layout_size, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
if raster_along_m:
|
||||
cluster_count_major = cluster_count_m
|
||||
cluster_count_minor = cluster_count_n
|
||||
else:
|
||||
cluster_count_major = cluster_count_n
|
||||
cluster_count_minor = cluster_count_m
|
||||
|
||||
# cluster_shape_major_fdd: Used to decode work_unit_id to cluster coordinates
|
||||
self.cluster_shape_major_fdd = cute.fast_divmod_create_divisor(
|
||||
cluster_count_major, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
# cluster_shape_minor_fdd: Used for the second level decomposition
|
||||
self.cluster_shape_minor_fdd = cute.fast_divmod_create_divisor(
|
||||
cluster_count_minor, loc=loc, ip=ip
|
||||
)
|
||||
else:
|
||||
# FastDivmod not applicable with swizzling, set to None
|
||||
self.batch_fdd = None
|
||||
self.cluster_shape_major_fdd = None
|
||||
self.cluster_shape_minor_fdd = None
|
||||
|
||||
def __extract_mlir_values__(self):
|
||||
values, self._values_pos = [], []
|
||||
for obj in [
|
||||
self.problem_shape_ntile_mnl,
|
||||
self._cluster_shape_mnk,
|
||||
self.swizzle_size,
|
||||
self.raster_along_m,
|
||||
]:
|
||||
obj_values = extract_mlir_values(obj)
|
||||
values += obj_values
|
||||
self._values_pos.append(len(obj_values))
|
||||
|
||||
# Add FastDivmod divisors to MLIR values for Host->Device transfer
|
||||
# Only add non-None values to avoid MLIR type errors
|
||||
fastdivmod_values = []
|
||||
fastdivmod_indices = [] # Track which FastDivmod objects are present
|
||||
|
||||
for i, (fdd_name, fdd_obj) in enumerate(
|
||||
[
|
||||
("batch_fdd", self.batch_fdd),
|
||||
("cluster_shape_major_fdd", self.cluster_shape_major_fdd),
|
||||
("cluster_shape_minor_fdd", self.cluster_shape_minor_fdd),
|
||||
]
|
||||
):
|
||||
if fdd_obj is not None:
|
||||
# Extract MLIR values from FastDivmodDivisor objects
|
||||
fdd_values = extract_mlir_values(fdd_obj)
|
||||
fastdivmod_values.extend(fdd_values)
|
||||
fastdivmod_indices.append(i)
|
||||
|
||||
values += fastdivmod_values
|
||||
self._values_pos.append(
|
||||
len(fastdivmod_indices)
|
||||
) # Store count of FastDivmod objects, not values
|
||||
self._fastdivmod_indices = fastdivmod_indices # Store for reconstruction
|
||||
|
||||
return values
|
||||
|
||||
def __new_from_mlir_values__(self, values):
|
||||
obj_list = []
|
||||
values_copy = list(values) # Make a copy to avoid modifying original
|
||||
|
||||
# Reconstruct original objects from MLIR values
|
||||
for obj, n_items in zip(
|
||||
[
|
||||
self.problem_shape_ntile_mnl,
|
||||
self._cluster_shape_mnk,
|
||||
self.swizzle_size,
|
||||
self.raster_along_m,
|
||||
],
|
||||
self._values_pos[:-1], # Exclude FastDivmod count
|
||||
):
|
||||
obj_list.append(new_from_mlir_values(obj, values_copy[:n_items]))
|
||||
values_copy = values_copy[n_items:]
|
||||
|
||||
# Create new params object by calling __init__ with reconstructed values
|
||||
# This properly recreates layouts and other derived attributes in the device context
|
||||
new_params = PersistentTileSchedulerParams(*(tuple(obj_list)), loc=self._loc)
|
||||
|
||||
# Restore FastDivmod divisors from remaining values
|
||||
fdd_names = ["batch_fdd", "cluster_shape_major_fdd", "cluster_shape_minor_fdd"]
|
||||
|
||||
if hasattr(self, "_fastdivmod_indices") and len(self._fastdivmod_indices) > 0:
|
||||
# Override the FastDivmod divisors created by __init__ with reconstructed ones
|
||||
for j, original_index in enumerate(self._fastdivmod_indices):
|
||||
fdd_name = fdd_names[original_index]
|
||||
# Get the original FastDivmodDivisor object
|
||||
original_fdd = getattr(self, fdd_name)
|
||||
if original_fdd is not None and j < len(values_copy):
|
||||
# Each FastDivmodDivisor has 1 MLIR value
|
||||
reconstructed_fdd = new_from_mlir_values(
|
||||
original_fdd, [values_copy[j]]
|
||||
)
|
||||
setattr(new_params, fdd_name, reconstructed_fdd)
|
||||
|
||||
return new_params
|
||||
|
||||
@dsl_user_op
|
||||
def get_grid_shape(
|
||||
self, max_active_clusters: Int32, *, loc=None, ip=None
|
||||
) -> Tuple[Integer, Integer, Integer]:
|
||||
"""
|
||||
Computes the grid shape based on the maximum active clusters allowed.
|
||||
|
||||
:param max_active_clusters: The maximum number of active clusters that
|
||||
can run in one wave.
|
||||
:type max_active_clusters: Int32
|
||||
|
||||
:return: A tuple containing the grid shape in (m, n, persistent_clusters).
|
||||
- m: self.cluster_shape_m.
|
||||
- n: self.cluster_shape_n.
|
||||
- persistent_clusters: Number of persistent clusters that can run.
|
||||
"""
|
||||
|
||||
# Total ctas in problem size
|
||||
num_ctas_mnl = tuple(
|
||||
cute.size(x) * y
|
||||
for x, y in zip(
|
||||
self.problem_layout_ncluster_mnl.shape, self.cluster_shape_mn
|
||||
)
|
||||
) + (self.problem_layout_ncluster_mnl.shape[2],)
|
||||
|
||||
num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip)
|
||||
|
||||
num_ctas_per_cluster = cute.size(self.cluster_shape_mn, loc=loc, ip=ip)
|
||||
# Total ctas that can run in one wave
|
||||
num_ctas_per_wave = max_active_clusters * num_ctas_per_cluster
|
||||
|
||||
num_persistent_ctas = min(num_ctas_in_problem, num_ctas_per_wave)
|
||||
num_persistent_clusters = num_persistent_ctas // num_ctas_per_cluster
|
||||
|
||||
return (*self.cluster_shape_mn, num_persistent_clusters)
|
||||
|
||||
|
||||
class StaticPersistentTileScheduler:
|
||||
"""A scheduler for static persistent tile execution in CUTLASS/CuTe kernels.
|
||||
|
||||
:ivar params: Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl
|
||||
:type params: PersistentTileSchedulerParams
|
||||
:ivar num_persistent_clusters: Number of persistent clusters that can be launched
|
||||
:type num_persistent_clusters: Int32
|
||||
:ivar cta_id_in_cluster: ID of the CTA within its cluster
|
||||
:type cta_id_in_cluster: cute.Coord
|
||||
:ivar _num_tiles_executed: Counter for executed tiles
|
||||
:type _num_tiles_executed: Int32
|
||||
:ivar _current_work_linear_idx: Current cluster index
|
||||
:type _current_work_linear_idx: Int32
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: PersistentTileSchedulerParams,
|
||||
num_persistent_clusters: Int32,
|
||||
current_work_linear_idx: Int32,
|
||||
cta_id_in_cluster: cute.Coord,
|
||||
num_tiles_executed: Int32,
|
||||
):
|
||||
"""
|
||||
Initializes the StaticPersistentTileScheduler with the given parameters.
|
||||
|
||||
:param params: Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl.
|
||||
:type params: PersistentTileSchedulerParams
|
||||
:param num_persistent_clusters: Number of persistent clusters that can be launched.
|
||||
:type num_persistent_clusters: Int32
|
||||
:param current_work_linear_idx: Current cluster index.
|
||||
:type current_work_linear_idx: Int32
|
||||
:param cta_id_in_cluster: ID of the CTA within its cluster.
|
||||
:type cta_id_in_cluster: cute.Coord
|
||||
:param num_tiles_executed: Counter for executed tiles.
|
||||
:type num_tiles_executed: Int32
|
||||
"""
|
||||
self.params = params
|
||||
self.num_persistent_clusters = num_persistent_clusters
|
||||
self._current_work_linear_idx = current_work_linear_idx
|
||||
self.cta_id_in_cluster = cta_id_in_cluster
|
||||
self._num_tiles_executed = num_tiles_executed
|
||||
|
||||
def __extract_mlir_values__(self) -> list[ir.Value]:
|
||||
values = extract_mlir_values(self.num_persistent_clusters)
|
||||
values.extend(extract_mlir_values(self._current_work_linear_idx))
|
||||
values.extend(extract_mlir_values(self.cta_id_in_cluster))
|
||||
values.extend(extract_mlir_values(self._num_tiles_executed))
|
||||
|
||||
# CRITICAL: Also extract FastDivmod divisors from params
|
||||
values.extend(extract_mlir_values(self.params))
|
||||
|
||||
return values
|
||||
|
||||
def __new_from_mlir_values__(
|
||||
self, values: list[ir.Value]
|
||||
) -> "StaticPersistentTileScheduler":
|
||||
assert len(values) >= 6
|
||||
new_num_persistent_clusters = new_from_mlir_values(
|
||||
self.num_persistent_clusters, [values[0]]
|
||||
)
|
||||
new_current_work_linear_idx = new_from_mlir_values(
|
||||
self._current_work_linear_idx, [values[1]]
|
||||
)
|
||||
new_cta_id_in_cluster = new_from_mlir_values(
|
||||
self.cta_id_in_cluster, values[2:5]
|
||||
)
|
||||
new_num_tiles_executed = new_from_mlir_values(
|
||||
self._num_tiles_executed, [values[5]]
|
||||
)
|
||||
|
||||
# Reconstruct params with FastDivmod divisors
|
||||
params_values = values[6:] # Remaining values are from params
|
||||
new_params = new_from_mlir_values(self.params, params_values)
|
||||
|
||||
return StaticPersistentTileScheduler(
|
||||
new_params, # Use reconstructed params with FastDivmod divisors
|
||||
new_num_persistent_clusters,
|
||||
new_current_work_linear_idx,
|
||||
new_cta_id_in_cluster,
|
||||
new_num_tiles_executed,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@dsl_user_op
|
||||
def create(
|
||||
params: PersistentTileSchedulerParams,
|
||||
block_idx: Tuple[Integer, Integer, Integer],
|
||||
grid_dim: Tuple[Integer, Integer, Integer],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Initialize the static persistent tile scheduler.
|
||||
|
||||
:param params: Parameters for the persistent
|
||||
tile scheduler.
|
||||
:type params: PersistentTileSchedulerParams
|
||||
:param block_idx: The 3d block index in the format (bidx, bidy, bidz).
|
||||
:type block_idx: Tuple[Integer, Integer, Integer]
|
||||
:param grid_dim: The 3d grid dimensions for kernel launch.
|
||||
:type grid_dim: Tuple[Integer, Integer, Integer]
|
||||
|
||||
:return: A StaticPersistentTileScheduler object.
|
||||
:rtype: StaticPersistentTileScheduler
|
||||
"""
|
||||
|
||||
# Calculate the number of persistent clusters by dividing the total grid size
|
||||
# by the number of CTAs per cluster
|
||||
num_persistent_clusters = cute.size(grid_dim, loc=loc, ip=ip) // cute.size(
|
||||
params.cluster_shape_mn, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
bidx, bidy, bidz = block_idx
|
||||
|
||||
# Initialize workload index equals to the cluster index in the grid
|
||||
current_work_linear_idx = Int32(bidz)
|
||||
|
||||
# CTA id in the cluster
|
||||
cta_id_in_cluster = (
|
||||
Int32(bidx % params.cluster_shape_mn[0]),
|
||||
Int32(bidy % params.cluster_shape_mn[1]),
|
||||
Int32(0),
|
||||
)
|
||||
# Initialize number of tiles executed to zero
|
||||
num_tiles_executed = Int32(0)
|
||||
return StaticPersistentTileScheduler(
|
||||
params,
|
||||
num_persistent_clusters,
|
||||
current_work_linear_idx,
|
||||
cta_id_in_cluster,
|
||||
num_tiles_executed,
|
||||
)
|
||||
|
||||
# called by host
|
||||
@staticmethod
|
||||
def get_grid_shape(
|
||||
params: PersistentTileSchedulerParams,
|
||||
max_active_clusters: Int32,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Tuple[Integer, Integer, Integer]:
|
||||
"""Calculates the grid shape to be launched on GPU using problem shape,
|
||||
threadblock shape, and active cluster size.
|
||||
|
||||
:param params: Parameters for grid shape calculation.
|
||||
:type params: PersistentTileSchedulerParams
|
||||
:param max_active_clusters: Maximum active clusters allowed.
|
||||
:type max_active_clusters: Int32
|
||||
|
||||
:return: The calculated 3d grid shape.
|
||||
:rtype: Tuple[Integer, Integer, Integer]
|
||||
"""
|
||||
|
||||
return params.get_grid_shape(max_active_clusters, loc=loc, ip=ip)
|
||||
|
||||
# private method
|
||||
def _get_current_work_for_linear_idx(
|
||||
self, current_work_linear_idx: Int32, *, loc=None, ip=None
|
||||
) -> WorkTileInfo:
|
||||
"""Compute current tile coord given current_work_linear_idx and cta_id_in_cluster.
|
||||
|
||||
:param current_work_linear_idx: The linear index of the current work.
|
||||
:type current_work_linear_idx: Int32
|
||||
|
||||
:return: An object containing information about the current tile coordinates
|
||||
and validity status.
|
||||
:rtype: WorkTileInfo
|
||||
"""
|
||||
|
||||
is_valid = current_work_linear_idx < cute.size(
|
||||
self.params.problem_layout_ncluster_mnl, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
# Choose coordinate calculation method based on swizzle configuration
|
||||
if self.params.swizzle_size == 1:
|
||||
# Use FastDivmod optimization for non-swizzled layouts
|
||||
cur_cluster_coord = self._get_cluster_work_idx_with_fastdivmod(
|
||||
current_work_linear_idx, loc=loc, ip=ip
|
||||
)
|
||||
else:
|
||||
# Use get_flat_coord for swizzled layouts (FastDivmod doesn't support them)
|
||||
cur_cluster_coord = self.params.problem_layout_ncluster_mnl.get_flat_coord(
|
||||
current_work_linear_idx, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
# cur_tile_coord is a tuple of i32 values
|
||||
cur_tile_coord = tuple(
|
||||
Int32(x) * Int32(z) + Int32(y)
|
||||
for x, y, z in zip(
|
||||
cur_cluster_coord,
|
||||
self.cta_id_in_cluster,
|
||||
(*self.params.cluster_shape_mn, Int32(1)),
|
||||
)
|
||||
)
|
||||
|
||||
return WorkTileInfo(cur_tile_coord, is_valid)
|
||||
|
||||
def _get_cluster_work_idx_with_fastdivmod(
|
||||
self, current_work_linear_idx: Int32, *, loc=None, ip=None
|
||||
) -> Tuple[Int32, Int32, Int32]:
|
||||
"""
|
||||
FastDivmod optimized CLUSTER coordinate calculation.
|
||||
|
||||
CRITICAL: This should mimic problem_layout_ncluster_mnl.get_hier_coord()
|
||||
which returns CLUSTER coordinates, not tile coordinates!
|
||||
|
||||
:param current_work_linear_idx: Linear index in the work space
|
||||
:type current_work_linear_idx: Int32
|
||||
:return: Cluster coordinates (m, n, l) or None if FastDivmod not available
|
||||
:rtype: Tuple[Int32, Int32, Int32] or None
|
||||
"""
|
||||
|
||||
# Step 1: Handle persistent scheduling - map linear_idx to work_unit_id
|
||||
work_iteration, work_unit_id = divmod(
|
||||
current_work_linear_idx, self.params.batch_fdd
|
||||
)
|
||||
|
||||
# Step 2: Decode work_unit_id using FastDivmod objects
|
||||
# The layout structure is: problem_layout_ncluster_mnl has shape (cluster_count_m, cluster_count_n, batch_count)
|
||||
# work_unit_id needs to be decomposed into (batch_l, cluster_minor, cluster_major) in little-endian order
|
||||
|
||||
# First, get cluster_major using cluster_shape_major_fdd
|
||||
cluster_minor_batch, cluster_major = divmod(
|
||||
work_unit_id, self.params.cluster_shape_major_fdd
|
||||
)
|
||||
|
||||
# Then decode cluster_minor_batch to get cluster_minor and batch_l using FastDivmod
|
||||
batch_l, cluster_minor = divmod(
|
||||
cluster_minor_batch, self.params.cluster_shape_minor_fdd
|
||||
)
|
||||
|
||||
if self.params.raster_along_m:
|
||||
cluster_m = cluster_major
|
||||
cluster_n = cluster_minor
|
||||
else:
|
||||
cluster_m = cluster_minor
|
||||
cluster_n = cluster_major
|
||||
|
||||
return (cluster_m, cluster_n, batch_l)
|
||||
|
||||
@dsl_user_op
|
||||
def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
|
||||
return self._get_current_work_for_linear_idx(
|
||||
self._current_work_linear_idx, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
@dsl_user_op
|
||||
def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo:
|
||||
return self.get_current_work(loc=loc, ip=ip)
|
||||
|
||||
@dsl_user_op
|
||||
def advance_to_next_work(self, *, advance_count: int = 1, loc=None, ip=None):
|
||||
self._current_work_linear_idx += Int32(advance_count) * Int32(
|
||||
self.num_persistent_clusters
|
||||
)
|
||||
self._num_tiles_executed += Int32(1)
|
||||
|
||||
@property
|
||||
def num_tiles_executed(self) -> Int32:
|
||||
return self._num_tiles_executed
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,347 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import itertools
|
||||
from collections.abc import Callable, Generator
|
||||
|
||||
import cutlass
|
||||
import cutlass.utils as utils
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.runtime import make_ptr
|
||||
|
||||
import cutlass_api
|
||||
from cutlass_api.arguments import (
|
||||
EpilogueArguments,
|
||||
GroupedGemmArguments,
|
||||
)
|
||||
from cutlass_api.artifact import CompiledArtifact
|
||||
from cutlass_api.metadata import (
|
||||
DesignMetadata,
|
||||
EpilogueMetadata,
|
||||
GroupedGemmOperandsMetadata,
|
||||
KernelMetadata,
|
||||
Sm100DesignMetadata,
|
||||
)
|
||||
from cutlass_api.providers.cutedsl import CuTeDSLProvider
|
||||
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
|
||||
from cutlass_api.providers.cutedsl.utils import get_max_active_clusters
|
||||
from cutlass_api.status import Status
|
||||
from cutlass_api.utils import (
|
||||
ceil_div,
|
||||
round_up,
|
||||
strides_to_layout_string,
|
||||
to_cuda_stream,
|
||||
tuple_to_string,
|
||||
)
|
||||
|
||||
from .implementations.sm100_contiguous_offset_2d3d_dense_gemm_impl import (
|
||||
ContiguousOffset2D3DGemmDenseKernelImpl,
|
||||
)
|
||||
from .utils import (
|
||||
tensor_rank_2_or_3,
|
||||
tensor_reduction_mode_matches,
|
||||
tensor_output_shape_matches,
|
||||
)
|
||||
|
||||
|
||||
@CuTeDSLProvider.register
class ContiguousOffset2D3DGemmDenseKernel(CuteDslKernel):
    """SM100 contiguous-offset 2D-3D grouped dense GEMM: (TotalM, K) @ (G, K, N) -> (TotalM, N).

    The rows of ``A``/``out`` are partitioned into ``G`` contiguous row groups;
    the group boundaries are supplied at runtime via a 1D ``offsets`` tensor
    with one entry per group (see ``_supports`` for the exact constraints).
    """

    def __init__(self, metadata: KernelMetadata):
        """Create the kernel wrapper and the underlying CuTe DSL implementation.

        Only the (M, N) components of the design's tile and cluster shapes are
        forwarded to the implementation.
        """
        super().__init__(metadata)
        mma_tiler_mn = (metadata.design.tile_shape[0], metadata.design.tile_shape[1])
        cluster_shape_mn = (
            metadata.design.cluster_shape[0],
            metadata.design.cluster_shape[1],
        )
        self.impl = ContiguousOffset2D3DGemmDenseKernelImpl(
            metadata.operands.accumulator_type,
            metadata.design.use_2cta_mma,
            mma_tiler_mn,
            cluster_shape_mn,
        )

    def _supports(self, args: GroupedGemmArguments) -> Status:
        """Check whether the runtime arguments are supported by this kernel.

        Returns ``Status.success()`` when all rank/shape constraints hold,
        otherwise a failing ``Status`` describing the first violation found.
        """
        if not (status := tensor_reduction_mode_matches(args)):
            return status

        if not (status := tensor_output_shape_matches(args)):
            return status

        # A and out may be rank 2 or rank 3; a rank-3 leading (batch) dim must be 1.
        if len(args.A.shape) == 3 and args.A.shape[0] != 1:
            return Status.fail(
                "Operand A must have batch size 1."
            )

        if len(args.out.shape) == 3 and args.out.shape[0] != 1:
            return Status.fail("out must have batch size 1.")

        # ValidM is the total row count across all groups (TotalM).
        ValidM, N = args.out.shape[-2:]
        K = args.A.shape[-1]
        group_count = args.B.shape[0] if len(args.B.shape) == 3 else 1

        if args.A.shape[-2] != ValidM or args.A.shape[-1] != K:
            return Status.fail(f"A must have shape ({ValidM}, {K}). Got {args.A.shape}")
        # NOTE(review): this exact-tuple comparison only passes for a rank-3 B;
        # a rank-2 B of shape (K, N) never equals (1, K, N) and is rejected here
        # — confirm that rank-2 B is intentionally unsupported.
        if args.B.shape != (group_count, K, N):
            return Status.fail(
                f"B must have shape ({group_count}, {K}, {N}). Got {args.B.shape}"
            )
        if args.out.shape[-2] != ValidM or args.out.shape[-1] != N:
            return Status.fail(
                f"out must have shape ({ValidM}, {N}). Got {args.out.shape}"
            )

        #
        # Check offsets
        #
        if len(args.offsets.shape) != 1:
            return Status.fail(f"offsets must be a 1D tensor. Got {args.offsets.shape}")

        # One offset entry per group. NOTE(review): the message below implies
        # the offsets are group *end* offsets ("End" mode) — confirm with callers.
        if args.offsets.numel() != group_count:
            return Status.fail(
                f"offsets must have {group_count} elements when offset mode is End and group count is {group_count}."
            )

        return Status.success()

    def compile(self, args: GroupedGemmArguments, cc: int | None = None) -> CompiledArtifact:
        """Ahead-of-time compile this kernel for the given argument shapes/layouts.

        A fake stream stands in for a real CUDA stream at compile time; the
        actual stream is supplied later via ``_run``.
        """
        stream = cutlass.cute.runtime.make_fake_stream()
        max_active_clusters = get_max_active_clusters(self.impl.cluster_shape_mn)

        compiled_kernel = self.cute_compile(
            self.impl,
            args.A.tensor,
            args.B.tensor,
            args.out.tensor,
            args.offsets.tensor,
            max_active_clusters,
            stream,
        )
        return CompiledArtifact(compiled_kernel, self)

    def _run(
        self,
        args: GroupedGemmArguments,
        compiled_artifact: CompiledArtifact,
        stream,
        workspace=None,
    ) -> None:
        """Launch the previously compiled kernel on ``stream``.

        ``workspace`` is accepted for interface compatibility but is unused here.
        """
        stream = to_cuda_stream(stream)
        compiled_gemm = compiled_artifact.compiled_obj
        self.cute_run(
            compiled_gemm,
            args.A.tensor,
            args.B.tensor,
            args.out.tensor,
            args.offsets.tensor,
            stream,
        )

    @staticmethod
    def _valid_operands(operands: GroupedGemmOperandsMetadata) -> bool:
        """Return True iff the operand dtypes/strides match this kernel's fixed layout."""
        if operands.accumulator_type != cutlass.Float32:
            return False

        if operands.A.stride[-2:].index(1) != 1:
            # A must be k-major
            return False
        if operands.B.stride[-2:].index(1) != 0:
            # B must be n-major
            # NOTE(review): the check requires unit stride in B's first (K)
            # mode, which the sibling _major_modes helper elsewhere in this
            # provider calls "k-major" — confirm the intended wording.
            return False
        if operands.out.stride[-2:].index(1) != 1:
            # out must be n-major
            return False

        return True

    @staticmethod
    def _valid_design_metadata(design: DesignMetadata) -> bool:
        """Return True iff the (M, N) tile/cluster configuration is supported by the impl."""
        use_2cta_mma = design.use_2cta_mma
        mma_tiler_mn = design.tile_shape[:2]
        cluster_shape_mn = design.cluster_shape[:2]
        impl = ContiguousOffset2D3DGemmDenseKernelImpl
        return impl.is_valid_mma_tiler_and_cluster_shape(
            use_2cta_mma, mma_tiler_mn, cluster_shape_mn
        )

    @staticmethod
    def _valid_epilogue_metadata(epilogue: EpilogueMetadata | None) -> bool:
        """This kernel supports no epilogue fusion: only ``None`` is valid."""
        return epilogue is None

    @staticmethod
    def _valid_metadata(metadata: KernelMetadata) -> bool:
        """Combine the operand, design, and epilogue validity checks."""
        if not ContiguousOffset2D3DGemmDenseKernel._valid_operands(metadata.operands):
            return False
        if not ContiguousOffset2D3DGemmDenseKernel._valid_design_metadata(
            metadata.design
        ):
            return False
        if not ContiguousOffset2D3DGemmDenseKernel._valid_epilogue_metadata(
            metadata.epilogue
        ):
            return False
        return True

    @staticmethod
    def _metadata_operand_combinations() -> Generator[
        GroupedGemmOperandsMetadata, None, None
    ]:
        """
        Generator that yields every GroupedGemmOperandsMetadata combination
        consistent with the validation rules in _valid_operands (dtype pairs are
        cross-checked against the implementation's is_valid_dtypes).
        """
        # Supported A/B data types (must be the same)
        ab_dtypes = [
            cutlass.Float8E5M2,
            cutlass.Float8E4M3FN,
            cutlass.Float16,
            cutlass.BFloat16,
        ]

        out_dtypes = [
            cutlass.Float32,
            cutlass.Float16,
            cutlass.BFloat16,
        ]

        # Stride patterns use 0 as a placeholder for non-unit strides; only the
        # position of the 1 (the contiguous mode) matters to the checks.
        row_major_stride = (0, 0, 1)
        col_major_stride = (0, 1, 0)
        alignment_bytes = 16
        offsets_alignment_bytes = 4

        # Fixed layout: A row-major, B column-major, out row-major (see _valid_operands).
        stride_A = row_major_stride
        stride_B = col_major_stride
        stride_out = row_major_stride

        acc_dtype = cutlass.Float32
        Impl = ContiguousOffset2D3DGemmDenseKernelImpl

        for ab_dtype, out_dtype in itertools.product(ab_dtypes, out_dtypes):
            if not Impl.is_valid_dtypes(ab_dtype, acc_dtype, out_dtype):
                continue

            # Divisibility = number of elements per aligned chunk.
            ab_divisibility = alignment_bytes * 8 // ab_dtype.width
            out_divisibility = alignment_bytes * 8 // out_dtype.width
            offsets_divisibility = offsets_alignment_bytes * 8 // cutlass.Int32.width

            operands = GroupedGemmOperandsMetadata(
                A=cutlass_api.metadata.DenseTensorAttributes(
                    dtype=ab_dtype,
                    stride=stride_A,
                    divisibility=ab_divisibility,
                ),
                B=cutlass_api.metadata.DenseTensorAttributes(
                    dtype=ab_dtype,
                    stride=stride_B,
                    divisibility=ab_divisibility,
                ),
                out=cutlass_api.metadata.DenseTensorAttributes(
                    dtype=out_dtype,
                    stride=stride_out,
                    divisibility=out_divisibility,
                ),
                offsets=cutlass_api.metadata.DenseTensorAttributes(
                    dtype=cutlass.Int32,
                    stride=(1,),
                    divisibility=offsets_divisibility,
                ),
                accumulator_type=acc_dtype,
            )

            yield operands

    @staticmethod
    def generate_kernels(
        metadata_filter: Callable[[KernelMetadata], bool],
        epilogue_args: EpilogueArguments | None = None,
        cc: int | None = None,
    ) -> list["ContiguousOffset2D3DGemmDenseKernel"]:
        """
        Returns a list of all possible configurations of ContiguousOffset2D3DGemmDenseKernel that
        adhere to constraints passed in under kwargs.
        """
        # Only SM100-family compute capabilities are supported.
        if cc is not None and cc not in [100, 101, 103]:
            return []

        # Tunable design-space parameters swept below (cartesian product).
        design_params = {
            "use_2cta_mma": [True, False],
            "tile_shape": [
                (64, 128, 128),
                (128, 128, 128),
                (256, 128, 128),
                (256, 256, 128),
            ],
            "cluster_shape": [(M, N, 1) for M in [1, 2, 4] for N in [1, 2, 4]],
            "use_tma_store": [True],
        }

        # This kernel does not support epilogue fusion.
        if epilogue_args is not None:
            return []

        from itertools import product

        # Get the list of tunable parameter names and their possible values
        param_names = list(design_params.keys())
        param_values = [design_params[name] for name in param_names]

        kernel_list = []

        for (
            operands
        ) in ContiguousOffset2D3DGemmDenseKernel._metadata_operand_combinations():
            for values in product(*param_values):
                design = Sm100DesignMetadata(**dict(zip(param_names, values)))
                # Unique, human-readable name encoding layout, dtypes, and design knobs.
                kernel_name = "cutedsl.ContiguousOffset2D3DGemmDenseKernel_sm100_{layout}_A{A}_B{B}_out{out}_acc{acc}_{num_cta}cta_cluster{cluster}_tile{tile}{_tma_store}".format(
                    layout=strides_to_layout_string(
                        operands.A.stride, operands.B.stride, operands.out.stride
                    ),
                    A=operands.A.dtype,
                    B=operands.B.dtype,
                    out=operands.out.dtype,
                    acc=operands.accumulator_type,
                    num_cta=("2" if design.use_2cta_mma else "1"),
                    cluster=tuple_to_string(design.cluster_shape),
                    tile=tuple_to_string(design.tile_shape),
                    _tma_store="_tma_store" if design.use_tma_store else "",
                )

                metadata = KernelMetadata(
                    operands=operands,
                    design=design,
                    kernel_name=kernel_name,
                    kernel_class=ContiguousOffset2D3DGemmDenseKernel,
                    min_cc=100,
                    epilogue=None,
                )

                metadata_valid = ContiguousOffset2D3DGemmDenseKernel._valid_metadata(
                    metadata
                )
                if metadata_valid and metadata_filter(metadata):
                    kernel_list.append(ContiguousOffset2D3DGemmDenseKernel(metadata))

        return kernel_list
|
||||
@@ -0,0 +1,418 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import itertools
|
||||
from collections.abc import Callable, Generator
|
||||
|
||||
import cutlass
|
||||
import cutlass.utils as utils
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode
|
||||
|
||||
from cutlass.cute.runtime import make_ptr
|
||||
|
||||
import cutlass_api
|
||||
from cutlass_api.arguments import (
|
||||
EpilogueArguments,
|
||||
GemmArguments,
|
||||
)
|
||||
from cutlass_api.artifact import CompiledArtifact
|
||||
from cutlass_api.metadata import (
|
||||
DenseTensorAttributes,
|
||||
DesignMetadata,
|
||||
EpilogueMetadata,
|
||||
GemmOperandsMetadata,
|
||||
KernelMetadata,
|
||||
Sm100DesignMetadata,
|
||||
ScaledTensorAttributes,
|
||||
)
|
||||
from cutlass_api.library import ScaleMode, ScaleSwizzleMode
|
||||
from cutlass_api.providers.cutedsl import CuTeDSLProvider
|
||||
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
|
||||
from cutlass_api.providers.cutedsl.utils import get_max_active_clusters
|
||||
from cutlass_api.status import Status
|
||||
from cutlass_api.utils import ceil_div, round_up, strides_to_layout_string, to_cuda_stream, tuple_to_string
|
||||
|
||||
from .implementations.sm100_dense_blockscaled_static_persistent_impl import PersistentDenseBlockScaledGemmKernelImpl
|
||||
|
||||
|
||||
@CuTeDSLProvider.register
class PersistentDenseBlockScaledGemmKernel(CuteDslKernel):
    """SM100 persistent block-scaled dense GEMM (e.g. MXFP8): out = A @ B with
    per-block scale factors SFA/SFB applied to A and B.

    Scale factors are expected in 32-4-4 swizzled storage; their shapes are not
    enforced beyond a total element count check (see ``_supports``).
    """

    def __init__(self, metadata: KernelMetadata):
        """Create the kernel wrapper and the underlying CuTe DSL implementation.

        The scale-factor vector size is taken from the last entry of A's scale
        mode; only the (M, N) components of the tile/cluster shapes are used.
        """
        super().__init__(metadata)
        self.sf_vec_size = metadata.operands.A.mode[-1]
        mma_tiler_mn = (metadata.design.tile_shape[0], metadata.design.tile_shape[1])
        cluster_shape_mn = (
            metadata.design.cluster_shape[0],
            metadata.design.cluster_shape[1],
        )
        self.impl = PersistentDenseBlockScaledGemmKernelImpl(
            self.sf_vec_size,
            mma_tiler_mn,
            cluster_shape_mn,
        )

    def _supports(self, args: GemmArguments) -> Status:
        """Check that the scale-factor tensors have the expected element counts.

        Returns ``Status.success()`` on success, otherwise a failing ``Status``.
        """
        M, N = args.out.shape[-2:]
        K = args.A.shape[-1]
        L = args.A.shape[0] if len(args.A.shape) == 3 else 1

        # To support 32-4-4 swizzling, the scale factor tensor must be padded to a multiple of 4
        expected_sf_k = round_up(ceil_div(K, self.sf_vec_size), 4)

        # Shapes of scale factor tensors are not enforced. It is expected that the
        # data underlying the passed in scale factor tensor is in 32-4-4 swizzled format.
        # Thus, we check only that the number of elements in the scale factor tensor is correct.
        expected_sfa_elements = L * M * expected_sf_k
        expected_sfb_elements = L * N * expected_sf_k

        if args.A.scale.numel() != expected_sfa_elements:
            return Status.fail(
                f"Scale factor A for tensor A of shape {args.A.shape} must have "
                f"{expected_sfa_elements} elements. Scale factor A is of shape {args.A.scale.shape} "
                f"and has {args.A.scale.numel()} elements."
            )
        if args.B.scale.numel() != expected_sfb_elements:
            return Status.fail(
                f"Scale factor B for tensor B of shape {args.B.shape} must have "
                f"{expected_sfb_elements} elements. Scale factor B is of shape {args.B.scale.shape} "
                f"and has {args.B.scale.numel()} elements."
            )

        return Status.success()

    def _construct_pointers(self, args: GemmArguments, nullptr: bool = False) -> tuple[cute.Pointer, cute.Pointer, cute.Pointer, cute.Pointer, cute.Pointer]:
        """Build CuTe gmem pointers for (A, B, out, SFA, SFB).

        With ``nullptr=True`` the pointers carry address 0 but keep the real
        element types/alignments — used to compile without concrete data.
        """
        def ptr(x): return 0 if nullptr else x.data_ptr
        gmem = cute.AddressSpace.gmem
        a_ptr = make_ptr(args.A.element_type, ptr(args.A), gmem, assumed_align=16)
        b_ptr = make_ptr(args.B.element_type, ptr(args.B), gmem, assumed_align=16)
        out_ptr = make_ptr(args.out.element_type, ptr(args.out), gmem, assumed_align=16)
        sfa_ptr = make_ptr(
            args.A.scale.element_type, ptr(args.A.scale), gmem, assumed_align=32
        )
        sfb_ptr = make_ptr(
            args.B.scale.element_type, ptr(args.B.scale), gmem, assumed_align=32
        )
        return a_ptr, b_ptr, out_ptr, sfa_ptr, sfb_ptr

    @staticmethod
    def _major_modes(args: GemmArguments | GemmOperandsMetadata) -> tuple[tuple[OperandMajorMode, str], tuple[OperandMajorMode, str], tuple[utils.LayoutEnum, str]]:
        """Derive (enum, name) major-mode pairs for A, B, and out from their strides.

        Works on both runtime arguments and operand metadata, since both expose
        ``.stride`` on A/B/out.
        """
        # A, B, and out can be of rank 2 or 3. Extract the final two dimensions
        # to determine the major mode
        if args.A.stride[-2:].index(1) == 1:
            a_major_mode = (OperandMajorMode.K, "k")
        else:
            a_major_mode = (OperandMajorMode.MN, "m")

        if args.B.stride[-2:].index(1) == 0:
            b_major_mode = (OperandMajorMode.K, "k")
        else:
            b_major_mode = (OperandMajorMode.MN, "n")

        if args.out.stride[-2:].index(1) == 1:
            out_layout = (utils.LayoutEnum.ROW_MAJOR, "n")
        else:
            out_layout = (utils.LayoutEnum.COL_MAJOR, "m")

        return a_major_mode, b_major_mode, out_layout

    def compile(self, args: GemmArguments, cc: int | None = None) -> CompiledArtifact:
        """Ahead-of-time compile the kernel using null pointers and a zero problem shape.

        Only dtypes, alignments, and major modes influence compilation; the real
        pointers and (m, n, k, l) shape are supplied at run time in ``_run``.
        """
        stream = cutlass.cute.runtime.make_fake_stream()
        max_active_clusters = get_max_active_clusters(self.impl.cluster_shape_mn)

        a_ptr, b_ptr, out_ptr, sfa_ptr, sfb_ptr = self._construct_pointers(args, nullptr=True)
        fake_problem_shape = (cutlass.Int32(0), cutlass.Int32(0), cutlass.Int32(0), cutlass.Int32(0))
        # Identity epilogue: no fusion supported by this kernel.
        epilogue_op = lambda x: x

        (a_major_mode, _), (b_major_mode, _), (out_layout, _) = PersistentDenseBlockScaledGemmKernel._major_modes(args)

        compiled_kernel = self.cute_compile(
            self.impl,
            a_ptr,
            b_ptr,
            sfa_ptr,
            sfb_ptr,
            out_ptr,
            (a_major_mode, b_major_mode, out_layout),
            fake_problem_shape,
            max_active_clusters,
            stream,
            epilogue_op,
        )
        return CompiledArtifact(compiled_kernel, self)

    def _run(
        self,
        args: GemmArguments,
        compiled_artifact: CompiledArtifact,
        stream,
        workspace=None,
    ) -> None:
        """Launch the compiled kernel with real pointers and the runtime (m, n, k, l).

        ``workspace`` is accepted for interface compatibility but is unused here.
        """
        stream = to_cuda_stream(stream)
        compiled_gemm = compiled_artifact.compiled_obj

        m, n = args.out.shape[-2:]
        k = args.A.shape[-1]
        l = args.A.shape[0] if len(args.A.shape) == 3 else 1

        a_ptr, b_ptr, out_ptr, sfa_ptr, sfb_ptr = self._construct_pointers(args)
        self.cute_run(compiled_gemm, a_ptr, b_ptr, sfa_ptr, sfb_ptr, out_ptr, (m, n, k, l), stream)

    @staticmethod
    def _valid_operands(operands: GemmOperandsMetadata, sf_vec_size: int) -> bool:
        """Return True iff the operand dtypes, scale dtypes, and layouts are supported."""
        if not isinstance(operands, GemmOperandsMetadata):
            return False

        # In current version, A and B tensor must have the same data type
        # i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported.
        # The same holds for scale A and scale B.
        if (
            operands.A.dtype != operands.B.dtype
            or operands.A.scale.dtype != operands.B.scale.dtype
        ):
            return False

        if operands.accumulator_type != cutlass.Float32:
            return False

        ab_dtype = operands.A.dtype
        sf_dtype = operands.A.scale.dtype
        out_dtype = operands.out.dtype

        impl = PersistentDenseBlockScaledGemmKernelImpl
        if not impl.is_valid_dtypes_and_scale_factor_vec_size(ab_dtype, sf_dtype, sf_vec_size, out_dtype):
            return False

        (_, a_major), (_, b_major), (_, out_major) = PersistentDenseBlockScaledGemmKernel._major_modes(operands)
        if not impl.is_valid_layouts(ab_dtype, out_dtype, a_major, b_major, out_major):
            return False

        return True

    @staticmethod
    def _valid_design_metadata(design: DesignMetadata) -> bool:
        """Return True iff the design is an SM100 design with a supported tile/cluster shape."""
        if not isinstance(design, Sm100DesignMetadata):
            return False

        mma_tiler_mn = design.tile_shape[:2]
        cluster_shape_mn = design.cluster_shape[:2]
        impl = PersistentDenseBlockScaledGemmKernelImpl
        return impl.is_valid_mma_tiler_and_cluster_shape(mma_tiler_mn, cluster_shape_mn)

    @staticmethod
    def _valid_epilogue_metadata(epilogue: EpilogueMetadata | None) -> bool:
        """This kernel supports no epilogue fusion: only ``None`` is valid."""
        return epilogue is None

    @staticmethod
    def _valid_metadata(metadata: KernelMetadata) -> bool:
        """Validate the scale mode shape plus operand/design/epilogue metadata."""
        scale_vec = metadata.operands.A.mode

        # Make sure scale vector is in the form (1, 1, ..., 1, sf_vec_size)
        if len(scale_vec) > 1:
            for i in range(0, len(scale_vec) - 1):
                if scale_vec[i] != 1:
                    return False

        sf_vec_size = scale_vec[-1]
        if not PersistentDenseBlockScaledGemmKernel._valid_operands(metadata.operands, sf_vec_size):
            return False
        if not PersistentDenseBlockScaledGemmKernel._valid_design_metadata(metadata.design):
            return False
        if not PersistentDenseBlockScaledGemmKernel._valid_epilogue_metadata(metadata.epilogue):
            return False
        return True

    @staticmethod
    def _metadata_operand_combinations() -> Generator[GemmOperandsMetadata, None, None]:
        """
        Generator that yields every GemmOperandsMetadata combination consistent
        with the validation rules in _valid_operands (dtype/scale/layout checks
        are delegated to the implementation).
        """
        # Supported A/B data types (must be the same)
        ab_dtypes = [
            cutlass.Float8E5M2,
            cutlass.Float8E4M3FN,
        ]

        out_dtypes = [
            cutlass.Float32,
            cutlass.Float16,
            cutlass.BFloat16,
            cutlass.Float8E4M3FN,
            cutlass.Float8E5M2,
        ]

        sf_dtypes = [
            cutlass.Float8E8M0FNU,
            cutlass.Float8E4M3FN,
        ]

        scale_modes = [ScaleMode.Blockwise1x16, ScaleMode.Blockwise1x32]

        # Stride patterns use 0 as a placeholder for non-unit strides; only the
        # position of the 1 (the contiguous mode) matters.
        row_major_stride = (0, 0, 1)
        col_major_stride = (0, 1, 0)
        alignment_bytes = 16

        # Map a stride pattern to the per-operand major-mode name used by the impl.
        def major_str_a(major_tuple: tuple[int, int, int]) -> str:
            return "k" if major_tuple == row_major_stride else "m"
        def major_str_b(major_tuple: tuple[int, int, int]) -> str:
            return "n" if major_tuple == row_major_stride else "k"
        def major_str_out(major_tuple: tuple[int, int, int]) -> str:
            return "n" if major_tuple == row_major_stride else "m"

        Impl = PersistentDenseBlockScaledGemmKernelImpl

        for ab_dtype, sf_dtype, scale_mode, out_dtype in itertools.product(ab_dtypes, sf_dtypes, scale_modes, out_dtypes):
            # NOTE(review): relies on ScaleMode members behaving as sequences
            # whose last entry is the SF vector size (e.g. Blockwise1x32 -> 32)
            # — confirm against the ScaleMode definition.
            sf_vec_size = scale_mode[-1]
            if not Impl.is_valid_dtypes_and_scale_factor_vec_size(ab_dtype, sf_dtype, sf_vec_size, out_dtype):
                continue

            for stride_A, stride_B, stride_out in itertools.product(
                [row_major_stride, col_major_stride], repeat=3
            ):
                a_major = major_str_a(stride_A)
                b_major = major_str_b(stride_B)
                out_major = major_str_out(stride_out)

                if not Impl.is_valid_layouts(ab_dtype, out_dtype, a_major, b_major, out_major):
                    continue

                # Divisibility = number of elements per aligned chunk.
                ab_divisibility = alignment_bytes * 8 // ab_dtype.width
                out_divisibility = alignment_bytes * 8 // out_dtype.width
                sf_divisibility = alignment_bytes * 8 // sf_dtype.width

                operands = GemmOperandsMetadata(
                    A=ScaledTensorAttributes(
                        base=DenseTensorAttributes(
                            dtype=ab_dtype,
                            stride=stride_A,
                            divisibility=ab_divisibility,
                        ),
                        scale=DenseTensorAttributes(
                            dtype=sf_dtype,
                            stride=None,
                            divisibility=sf_divisibility,
                        ),
                        mode=scale_mode,
                        swizzle=ScaleSwizzleMode.Swizzle32x4x4,
                    ),
                    B=ScaledTensorAttributes(
                        base=DenseTensorAttributes(
                            dtype=ab_dtype,
                            stride=stride_B,
                            divisibility=ab_divisibility,
                        ),
                        scale=DenseTensorAttributes(
                            dtype=sf_dtype,
                            stride=None,
                            divisibility=sf_divisibility,
                        ),
                        mode=scale_mode,
                        swizzle=ScaleSwizzleMode.Swizzle32x4x4,
                    ),
                    out=DenseTensorAttributes(
                        dtype=out_dtype,
                        stride=stride_out,
                        divisibility=out_divisibility,
                    ),
                    accumulator_type=cutlass.Float32,
                )

                yield operands

    @staticmethod
    def generate_kernels(
        metadata_filter: Callable[[KernelMetadata], bool],
        epilogue_args: EpilogueArguments | None = None,
        cc: int | None = None,
    ) -> list["PersistentDenseBlockScaledGemmKernel"]:
        """
        Returns a list of all possible configurations of PersistentDenseBlockScaledGemmKernel that
        adhere to constraints passed in under kwargs.
        """
        # Only SM100-family compute capabilities are supported.
        if cc is not None and cc not in [100, 101, 103]:
            return []

        # Tunable design-space parameters swept below (cartesian product).
        design_params = {
            "use_2cta_mma": [True],
            "tile_shape": [
                (M, N, 256) for M in [128, 256] for N in [64, 128, 192, 256]
            ],
            "cluster_shape": [
                (M, N, 1) for M in [2, 4] for N in [1, 2, 4]
            ],
            "use_tma_store": [True],
        }

        # This kernel does not support epilogue fusion.
        if epilogue_args is not None:
            return []

        from itertools import product

        # Get the list of tunable parameter names and their possible values
        param_names = list(design_params.keys())
        param_values = [design_params[name] for name in param_names]

        kernel_list = []

        for operands in PersistentDenseBlockScaledGemmKernel._metadata_operand_combinations():
            for values in product(*param_values):
                design = Sm100DesignMetadata(**dict(zip(param_names, values)))
                # Unique, human-readable name encoding layout, dtypes, scale config, and design knobs.
                kernel_name = "cutedsl.PersistentDenseBlockScaledGemmKernel_sm100_{layout}_A{A}_B{B}_out{out}_SFA{SFA}_SFB{SFB}_acc{acc}_scale{scale_mode}_swizzle{scale_swizzle}_{num_cta}cta_cluster{cluster}_tile{tile}{_tma_store}".format(
                    layout=strides_to_layout_string(
                        operands.A.stride, operands.B.stride, operands.out.stride
                    ),
                    A=operands.A.dtype,
                    B=operands.B.dtype,
                    out=operands.out.dtype,
                    SFA=operands.A.scale.dtype,
                    SFB=operands.B.scale.dtype,
                    acc=operands.accumulator_type,
                    scale_mode=operands.A.mode,
                    scale_swizzle=operands.A.swizzle,
                    num_cta=("2" if design.use_2cta_mma else "1"),
                    cluster=tuple_to_string(design.cluster_shape),
                    tile=tuple_to_string(design.tile_shape),
                    _tma_store="_tma_store" if design.use_tma_store else "",
                )

                metadata = KernelMetadata(
                    operands=operands,
                    design=design,
                    kernel_name=kernel_name,
                    kernel_class=PersistentDenseBlockScaledGemmKernel,
                    min_cc=100,
                    epilogue=None,
                )

                metadata_valid = PersistentDenseBlockScaledGemmKernel._valid_metadata(metadata)

                if metadata_valid and metadata_filter(metadata):
                    kernel_list.append(PersistentDenseBlockScaledGemmKernel(metadata))

        return kernel_list
|
||||
@@ -37,10 +37,10 @@ from cutlass_api.arguments import (
|
||||
)
|
||||
from cutlass_api.artifact import CompiledArtifact
|
||||
from cutlass_api.metadata import (
|
||||
DenseTensorAttributes,
|
||||
GemmOperandsMetadata,
|
||||
KernelMetadata,
|
||||
Sm100DesignMetadata,
|
||||
TensorAttributes,
|
||||
)
|
||||
from cutlass_api.providers.cutedsl import CuTeDSLProvider
|
||||
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
|
||||
@@ -104,21 +104,15 @@ class PersistentDenseGemmKernel(CuteDslKernel):
|
||||
epilogue_op,
|
||||
)
|
||||
|
||||
def _supports(self, args: GemmArguments) -> Status:
|
||||
|
||||
if args.epilogue is not None:
|
||||
return Status.fail("This kernel does not support any epilogue fusion.")
|
||||
return Status.success()
|
||||
|
||||
def compile(self, args: GemmArguments, cc: int = None) -> CompiledArtifact:
|
||||
stream = cutlass.cute.runtime.make_fake_stream()
|
||||
max_active_clusters = get_max_active_clusters(self.impl.cluster_shape_mn)
|
||||
|
||||
compiled_kernel = self.cute_compile(
|
||||
self.impl,
|
||||
args.A,
|
||||
args.B,
|
||||
args.out,
|
||||
args.A.tensor,
|
||||
args.B.tensor,
|
||||
args.out.tensor,
|
||||
max_active_clusters,
|
||||
stream,
|
||||
self.impl.epilogue_op,
|
||||
@@ -134,7 +128,9 @@ class PersistentDenseGemmKernel(CuteDslKernel):
|
||||
) -> None:
|
||||
stream = to_cuda_stream(stream)
|
||||
compiled_gemm = compiled_artifact.compiled_obj
|
||||
self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)
|
||||
self.cute_run(
|
||||
compiled_gemm, args.A.tensor, args.B.tensor, args.out.tensor, stream
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _valid_operands(operands: GemmOperandsMetadata) -> bool:
|
||||
@@ -279,17 +275,17 @@ class PersistentDenseGemmKernel(CuteDslKernel):
|
||||
ab_divisibility = alignment_bytes * 8 // ab_dtype.width
|
||||
out_divisibility = alignment_bytes * 8 // out_dtype.width
|
||||
# Create TensorAttributes for A, B, and out tensors
|
||||
a_attrs = TensorAttributes(
|
||||
a_attrs = DenseTensorAttributes(
|
||||
dtype=ab_dtype,
|
||||
stride=stride_A,
|
||||
divisibility=ab_divisibility,
|
||||
)
|
||||
b_attrs = TensorAttributes(
|
||||
b_attrs = DenseTensorAttributes(
|
||||
dtype=ab_dtype,
|
||||
stride=stride_B,
|
||||
divisibility=ab_divisibility,
|
||||
)
|
||||
out_attrs = TensorAttributes(
|
||||
out_attrs = DenseTensorAttributes(
|
||||
dtype=out_dtype,
|
||||
stride=stride_out,
|
||||
divisibility=out_divisibility,
|
||||
|
||||
@@ -37,11 +37,11 @@ from cutlass_api.arguments import (
|
||||
)
|
||||
from cutlass_api.artifact import CompiledArtifact
|
||||
from cutlass_api.metadata import (
|
||||
DenseTensorAttributes,
|
||||
EpilogueMetadata,
|
||||
GemmOperandsMetadata,
|
||||
KernelMetadata,
|
||||
Sm100DesignMetadata,
|
||||
TensorAttributes,
|
||||
)
|
||||
from cutlass_api.providers.cutedsl import CuTeDSLProvider
|
||||
from cutlass_api.providers.cutedsl.evt.converter import EFCConverter
|
||||
@@ -284,17 +284,17 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
|
||||
ab_divisibility = alignment_bytes * 8 // ab_dtype.width
|
||||
out_divisibility = alignment_bytes * 8 // out_dtype.width
|
||||
# Create TensorAttributes for A, B, and out tensors
|
||||
a_attrs = TensorAttributes(
|
||||
a_attrs = DenseTensorAttributes(
|
||||
dtype=ab_dtype,
|
||||
stride=stride_A,
|
||||
divisibility=ab_divisibility,
|
||||
)
|
||||
b_attrs = TensorAttributes(
|
||||
b_attrs = DenseTensorAttributes(
|
||||
dtype=ab_dtype,
|
||||
stride=stride_B,
|
||||
divisibility=ab_divisibility,
|
||||
)
|
||||
out_attrs = TensorAttributes(
|
||||
out_attrs = DenseTensorAttributes(
|
||||
dtype=out_dtype,
|
||||
stride=stride_out,
|
||||
divisibility=out_divisibility,
|
||||
@@ -311,7 +311,6 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
|
||||
yield operands
|
||||
|
||||
def _supports(self, args: GemmArguments) -> Status:
|
||||
|
||||
if args.epilogue is not None:
|
||||
fusion_metadata = EpilogueMetadata.from_args(args.epilogue)
|
||||
if not self._valid_fusion(fusion_metadata):
|
||||
@@ -339,8 +338,8 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
|
||||
self.impl.efc.compile(epilogue_params)
|
||||
compiled_gemm = self.cute_compile(
|
||||
self.impl,
|
||||
args.A,
|
||||
args.B,
|
||||
args.A.tensor,
|
||||
args.B.tensor,
|
||||
max_active_clusters,
|
||||
stream,
|
||||
self.impl.efc.jit.pack_arguments(*epilogue_params),
|
||||
@@ -376,7 +375,9 @@ class PersistentDenseGemmEFCKernel(CuteDslKernel):
|
||||
epilogue_params = [args.out]
|
||||
|
||||
compiled_gemm = compiled_artifact.compiled_obj
|
||||
self.cute_run(compiled_gemm, args.A, args.B, stream, *epilogue_params)
|
||||
self.cute_run(
|
||||
compiled_gemm, args.A.tensor, args.B.tensor, stream, *epilogue_params
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def generate_kernels(
|
||||
|
||||
@@ -37,15 +37,16 @@ from cutlass_api.arguments import (
|
||||
)
|
||||
from cutlass_api.artifact import CompiledArtifact
|
||||
from cutlass_api.metadata import (
|
||||
DenseTensorAttributes,
|
||||
GemmOperandsMetadata,
|
||||
KernelMetadata,
|
||||
TensorAttributes,
|
||||
)
|
||||
from cutlass_api.providers.cutedsl import CuTeDSLProvider
|
||||
from cutlass_api.providers.cutedsl.kernel import CuteDslKernel
|
||||
from cutlass_api.status import Status
|
||||
from cutlass_api.utils import strides_to_layout_string, to_cuda_stream, tuple_to_string
|
||||
from cutlass_api.metadata import BLASDesignMetadata
|
||||
|
||||
from .implementations.sm80_tensorop_gemm_impl import Sm80TensorOpGemmImpl
|
||||
|
||||
|
||||
@@ -77,20 +78,14 @@ class Sm80TensorOpGemmKernel(CuteDslKernel):
|
||||
metadata.operands.accumulator_type
|
||||
)
|
||||
|
||||
def _supports(self, args: GemmArguments) -> Status:
|
||||
|
||||
if args.epilogue is not None:
|
||||
return Status.fail("This kernel does not support any epilogue fusion.")
|
||||
return Status.success()
|
||||
|
||||
def compile(self, args: GemmArguments, cc: int = None) -> CompiledArtifact:
|
||||
stream = cutlass.cute.runtime.make_fake_stream()
|
||||
|
||||
compiled_kernel = self.cute_compile(
|
||||
self.impl,
|
||||
args.A,
|
||||
args.B,
|
||||
args.out,
|
||||
args.A.tensor,
|
||||
args.B.tensor,
|
||||
args.out.tensor,
|
||||
stream,
|
||||
)
|
||||
return CompiledArtifact(compiled_kernel, self)
|
||||
@@ -104,7 +99,9 @@ class Sm80TensorOpGemmKernel(CuteDslKernel):
|
||||
) -> None:
|
||||
stream = to_cuda_stream(stream)
|
||||
compiled_gemm = compiled_artifact.compiled_obj
|
||||
self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)
|
||||
self.cute_run(
|
||||
compiled_gemm, args.A.tensor, args.B.tensor, args.out.tensor, stream
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _valid_operands(operands: GemmOperandsMetadata) -> bool:
|
||||
@@ -155,17 +152,17 @@ class Sm80TensorOpGemmKernel(CuteDslKernel):
|
||||
out_divisibility = alignment_bytes * 8 // out_dtype.width
|
||||
|
||||
# Create TensorAttributes for A, B, and out tensors
|
||||
a_attrs = TensorAttributes(
|
||||
a_attrs = DenseTensorAttributes(
|
||||
dtype=ab_dtype,
|
||||
stride=stride_A,
|
||||
divisibility=ab_divisibility,
|
||||
)
|
||||
b_attrs = TensorAttributes(
|
||||
b_attrs = DenseTensorAttributes(
|
||||
dtype=ab_dtype,
|
||||
stride=stride_B,
|
||||
divisibility=ab_divisibility,
|
||||
)
|
||||
out_attrs = TensorAttributes(
|
||||
out_attrs = DenseTensorAttributes(
|
||||
dtype=out_dtype,
|
||||
stride=stride_out,
|
||||
divisibility=out_divisibility,
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from cutlass_api.arguments import RuntimeArguments
|
||||
from cutlass_api.status import Status
|
||||
|
||||
|
||||
def tensor_rank_2_or_3(args: RuntimeArguments) -> Status:
|
||||
"""
|
||||
Checks that the arguments are valid for a tensor of rank 2 or 3.
|
||||
"""
|
||||
if len(args.A.shape) < 2 or len(args.A.shape) > 3:
|
||||
return Status.fail(
|
||||
f"A must be a tensor of rank 2 or 3 (L=1, M, K), got {args.A.shape}"
|
||||
)
|
||||
if len(args.B.shape) < 2 or len(args.B.shape) > 3:
|
||||
return Status.fail(
|
||||
f"B must be a tensor of rank 2 or 3 (L=1, K, N), got {len(args.B.shape)}"
|
||||
)
|
||||
if len(args.out.shape) < 2 or len(args.out.shape) > 3:
|
||||
return Status.fail(
|
||||
f"out must be a tensor of rank 2 or 3 (L=1, M, N), got {len(args.out.shape)}"
|
||||
)
|
||||
return Status.success()
|
||||
|
||||
|
||||
def tensor_reduction_mode_matches(args: RuntimeArguments) -> Status:
|
||||
if args.A.shape[-1] != args.B.shape[-2]:
|
||||
return Status.fail(
|
||||
f"A's K dimension ({args.A.shape[-1]}) must be equal to B's "
|
||||
f"K dimension ({args.B.shape[-2]}). "
|
||||
f"A shape (L, M, K): {args.A.shape}, B shape (L, K, N): {args.B.shape}"
|
||||
)
|
||||
return Status.success()
|
||||
|
||||
|
||||
def tensor_output_shape_matches(args: RuntimeArguments) -> Status:
|
||||
if args.out.shape[-2] != args.A.shape[-2]:
|
||||
return Status.fail(
|
||||
f"out's M dimension ({args.out.shape[-2]}) must be equal to A's "
|
||||
f"M dimension ({args.A.shape[-2]}). "
|
||||
f"A shape (L, M, K): {args.A.shape}, out shape (L, M, N): {args.out.shape}"
|
||||
)
|
||||
if args.out.shape[-1] != args.B.shape[-1]:
|
||||
return Status.fail(
|
||||
f"out's N dimension ({args.out.shape[-1]}) must be equal to B's "
|
||||
f"N dimension ({args.B.shape[-1]}). "
|
||||
f"B shape (L, K, N): {args.B.shape}, out shape (L, M, N): {args.out.shape}"
|
||||
)
|
||||
return Status.success()
|
||||
@@ -47,6 +47,36 @@ if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
|
||||
def ceil_div(a: int, b: int) -> int:
|
||||
"""
|
||||
Compute the ceiling division of a by b.
|
||||
|
||||
:param a: The dividend.
|
||||
:type a: int
|
||||
:param b: The divisor.
|
||||
:type b: int
|
||||
|
||||
:return: The ceiling division of a by b.
|
||||
:rtype: int
|
||||
"""
|
||||
return (a + b - 1) // b
|
||||
|
||||
|
||||
def round_up(a: int, b: int) -> int:
|
||||
"""
|
||||
Round a up to the nearest multiple of b.
|
||||
|
||||
:param a: The value to round up.
|
||||
:type a: int
|
||||
:param b: The multiple to round up to.
|
||||
:type b: int
|
||||
|
||||
:return: The value of a rounded up to the nearest multiple of b.
|
||||
:rtype: int
|
||||
"""
|
||||
return ((a + b - 1) // b) * b
|
||||
|
||||
|
||||
def is_numpy_available() -> bool:
|
||||
"""Check if numpy is available."""
|
||||
return importlib.util.find_spec("numpy") is not None
|
||||
@@ -173,6 +203,7 @@ def cutlass_type_from_torch_type(dtype) -> type[cutlass.Numeric]:
|
||||
torch.float8_e5m2: cutlass.Float8E5M2,
|
||||
torch.float8_e4m3fn: cutlass.Float8E4M3FN,
|
||||
torch.float8_e4m3fnuz: cutlass.Float8E4M3B11FNUZ,
|
||||
torch.float8_e8m0fnu: cutlass.Float8E8M0FNU,
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -236,52 +267,6 @@ def to_cuda_stream(
|
||||
raise ValueError(f"Unsupported stream type: {type(stream)}")
|
||||
|
||||
|
||||
def add_batch_mode(
|
||||
tensor: cute.Tensor | torch.Tensor,
|
||||
) -> cute.Tensor | torch.Tensor:
|
||||
"""
|
||||
Adds a batch mode to the tensor.
|
||||
If the tensor is a torch.Tensor and has rank 2,
|
||||
it will be unsqueezed along the first dimension.
|
||||
|
||||
:param tensor: The tensor to add batch mode to.
|
||||
:type tensor: Union[cute.Tensor, "torch.Tensor"]
|
||||
|
||||
:return: The tensor with batch mode added.
|
||||
:rtype: Union[cute.Tensor, "torch.Tensor"]
|
||||
"""
|
||||
if is_torch_tensor(tensor):
|
||||
if tensor.dim() == 2:
|
||||
return tensor.unsqueeze(0)
|
||||
elif tensor.dim() < 2 or tensor.dim() > 3:
|
||||
raise ValueError(f"Expected 2-3 dimensions, got {tensor.dim()}")
|
||||
return tensor
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def permute_batch_mode(
|
||||
tensor: cute.Tensor | torch.Tensor,
|
||||
) -> cute.Tensor | torch.Tensor:
|
||||
"""
|
||||
Permute the batch mode of the tensor.
|
||||
If the tensor is a torch.Tensor and has rank 3,
|
||||
it will be permuted along the first dimension.
|
||||
|
||||
:param tensor: The tensor to permute.
|
||||
:type tensor: Union[cute.Tensor, "torch.Tensor"]
|
||||
|
||||
:return: The tensor with batch mode permuted.
|
||||
:rtype: Union[cute.Tensor, "torch.Tensor"]
|
||||
"""
|
||||
if is_torch_tensor(tensor):
|
||||
if tensor.dim() != 3:
|
||||
raise ValueError(f"Expected 3 dimensions, got {tensor.dim()}")
|
||||
return tensor.permute([1, 2, 0])
|
||||
else:
|
||||
raise ValueError(f"Unsupported type: {type(tensor)}")
|
||||
|
||||
|
||||
def leading_dim(tensor) -> int:
|
||||
"""
|
||||
Get the leading dimension of a tensor. This is the first mode with
|
||||
@@ -378,7 +363,7 @@ class TensorWrapper:
|
||||
is enabled.
|
||||
"""
|
||||
|
||||
def __init__(self, tensor: Any):
|
||||
def __init__(self, tensor: Any, alignment_bytes: int = 16):
|
||||
if isinstance(tensor, cute.Tensor):
|
||||
# Regardless of whether TVM-FFI is enabled, if the tensor passed in is a cute.Tensor,
|
||||
# it can be used as the runtime tensor and compile time tensor.
|
||||
@@ -399,7 +384,9 @@ class TensorWrapper:
|
||||
stride_order = get_stride_rank(self._stride)
|
||||
leading_dim_idx = stride_order.index(0)
|
||||
shape = [cute.SymInt() for _ in range(rank)]
|
||||
shape[leading_dim_idx] = cute.SymInt(divisibility=16 * 8 // dtype.width)
|
||||
shape[leading_dim_idx] = cute.SymInt(
|
||||
divisibility=alignment_bytes * 8 // dtype.width
|
||||
)
|
||||
self._shape = tuple(self.runtime_tensor.shape)
|
||||
self._data_ptr = self.runtime_tensor.data_ptr()
|
||||
else:
|
||||
@@ -411,7 +398,7 @@ class TensorWrapper:
|
||||
dtype,
|
||||
shape,
|
||||
stride_order=stride_order,
|
||||
assumed_align=16, # bytes
|
||||
assumed_align=alignment_bytes,
|
||||
)
|
||||
else:
|
||||
# TVM-FFI is disabled and the tensor passed in is not a cute.Tensor,
|
||||
@@ -425,12 +412,12 @@ class TensorWrapper:
|
||||
self.runtime_tensor = (
|
||||
from_dlpack(
|
||||
tensor,
|
||||
assumed_align=16, # bytes
|
||||
assumed_align=alignment_bytes,
|
||||
)
|
||||
.mark_layout_dynamic(leading_dim(tensor))
|
||||
.mark_compact_shape_dynamic(
|
||||
mode=leading_dim(tensor),
|
||||
divisibility=16 * 8 // dtype.width,
|
||||
divisibility=alignment_bytes * 8 // dtype.width,
|
||||
stride_order=stride_order,
|
||||
)
|
||||
)
|
||||
@@ -459,6 +446,12 @@ class TensorWrapper:
|
||||
def data_ptr(self) -> int:
|
||||
return self._data_ptr
|
||||
|
||||
def numel(self) -> int:
|
||||
num = self._shape[0]
|
||||
for i in range(1, len(self._shape)):
|
||||
num *= self._shape[i]
|
||||
return num
|
||||
|
||||
|
||||
def strides_to_layout_string(*strides: list[tuple[int, ...]]) -> str:
|
||||
"""
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
"id": "3dd45ef2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Basic GEMM using CUTLASS Python API"
|
||||
"# Basic GEMM using CUTLASS API"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -229,7 +229,7 @@
|
||||
" self, args: GemmArguments, cc: int = None\n",
|
||||
") -> cutlass_api.artifact.CompiledArtifact:\n",
|
||||
" stream = cutlass.cute.runtime.make_fake_stream()\n",
|
||||
" compiled_gemm = self.cute_compile(self.impl, args.A, args.B, args.out, stream)\n",
|
||||
" compiled_gemm = self.cute_compile(self.impl, args.A.tensor, args.B.tensor, args.out.tensor, stream)\n",
|
||||
" return cutlass_api.artifact.CompiledArtifact(compiled_gemm, self)"
|
||||
]
|
||||
},
|
||||
@@ -260,7 +260,7 @@
|
||||
"):\n",
|
||||
" stream = cutlass_api.utils.to_cuda_stream(stream)\n",
|
||||
" compiled_gemm = artifact.compiled_obj\n",
|
||||
" self.cute_run(compiled_gemm, args.A, args.B, args.out, stream)"
|
||||
" self.cute_run(compiled_gemm, args.A.tensor, args.B.tensor, args.out.tensor, stream)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -343,13 +343,13 @@
|
||||
" design_metadata = cutlass_api.metadata.BLASDesignMetadata(tile_shape, (1, 1, 1))\n",
|
||||
" for stride_A, stride_B, stride_out in stride_combos:\n",
|
||||
" # Create TensorAttributes for A, B, and out tensors\n",
|
||||
" a_attrs = cutlass_api.metadata.TensorAttributes(\n",
|
||||
" a_attrs = cutlass_api.metadata.DenseTensorAttributes(\n",
|
||||
" cutlass.Float64, stride_A, divisibility\n",
|
||||
" )\n",
|
||||
" b_attrs = cutlass_api.metadata.TensorAttributes(\n",
|
||||
" b_attrs = cutlass_api.metadata.DenseTensorAttributes(\n",
|
||||
" cutlass.Float64, stride_B, divisibility\n",
|
||||
" )\n",
|
||||
" out_attrs = cutlass_api.metadata.TensorAttributes(\n",
|
||||
" out_attrs = cutlass_api.metadata.DenseTensorAttributes(\n",
|
||||
" cutlass.Float64, stride_out, divisibility\n",
|
||||
" )\n",
|
||||
" layout_str = cutlass_api.utils.strides_to_layout_string(\n",
|
||||
|
||||
@@ -0,0 +1,240 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "91d43c2b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Grouped GEMM with contiguous tensors via the CUTLASS API\n",
|
||||
"\n",
|
||||
"Note: this notebook requires a GPU with compute capability 100:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "f671f602",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cutlass_api\n",
|
||||
"\n",
|
||||
"if not (status := cutlass_api.utils.is_device_cc_supported({100})):\n",
|
||||
" print(\n",
|
||||
" f\"This notebook requires a GPU with compute capability 100.\\n{status.error}\"\n",
|
||||
" )\n",
|
||||
" import sys\n",
|
||||
" sys.exit(0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bc4adf7d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This notebook shows how to use the CUTLASS API to discover, compile, and execute\n",
|
||||
"kernels supporting contiguous offset grouped GEMMs.\n",
|
||||
"\n",
|
||||
"In a \"contiguous offset\" grouped GEMM, `G` different problems are executed\n",
|
||||
"in which problems differ only in the `M` mode. Their problem sizes are thus\n",
|
||||
"represented as:\n",
|
||||
"\n",
|
||||
"```text\n",
|
||||
"M0 x N x K\n",
|
||||
"M1 x N x K\n",
|
||||
"M2 x N x K\n",
|
||||
"...\n",
|
||||
"M(G-1) x N x K\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The grouped GEMM is referred to as \"contiguous\" because operands for different\n",
|
||||
"problems in the group are contained within contiguous tensors.\n",
|
||||
"\n",
|
||||
"Rather than having `G` different tensors for each of operands `A` and `B`, tensors\n",
|
||||
"for different problems in the group are packed together:\n",
|
||||
"* `A` is of shape `(TotalM, K)`, where `TotalM` is the sum of all `M` modes for problems in the group.\n",
|
||||
"The `A` operands for each problem in the group are stacked along the `M` mode to form this input. More on this below.\n",
|
||||
"* `B` is of shape `(G, K, N)`, where `B[i, :, :]` represents the GEMM `B` operand for the `i`th problem in the group.\n",
|
||||
"\n",
|
||||
"For example, with `G=3` (three problems in the group), with `M` modes of M0, M1, and M2,\n",
|
||||
"respectively, the tensor `A` would be laid out as follows:\n",
|
||||
"\n",
|
||||
"```text\n",
|
||||
"\n",
|
||||
" +----------------------------------+ ^ \n",
|
||||
" | | | | \n",
|
||||
" | A0 | M0 | \n",
|
||||
" | | | | \n",
|
||||
" |- - - - - - - - - - - -| | \n",
|
||||
" | | | |\n",
|
||||
" | | | TotalM \n",
|
||||
" | A1 | M1 |\n",
|
||||
" | | | |\n",
|
||||
" | | | | \n",
|
||||
" |- - - - - - - - - - - -| | \n",
|
||||
" | A2 | M2 | \n",
|
||||
" +----------------------------------+ v \n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"The extents of individual `A` operands packed within the overall contiguous offset `A` tensor\n",
|
||||
"are provided by an auxiliary `offsets` vector of shape `(G,)`. `offsets[i]` indicates the ending\n",
|
||||
"M coordinate (exclusive) for the `i`th `A` operand.\n",
|
||||
"\n",
|
||||
"Thus, for the example above, `offsets = [M0, M0 + M1, M0 + M1 + M2]`.\n",
|
||||
"\n",
|
||||
"The output of the operation is of shape `(TotalM, N)`. The `i`th output occupies `out[start:end, :]`,\n",
|
||||
"where `start` and `end` are `offsets[i-1]` and `offsets[i]`, respectively (unless `i=0`, in which case\n",
|
||||
"`start` is 0).\n",
|
||||
"\n",
|
||||
"The reference code below shows the computation of this kernel."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "6185f60a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"def reference_contiguous_offset_grouped_gemm(A, B, offsets, out_dtype):\n",
|
||||
" G, K, N = B.shape\n",
|
||||
" TotalM = A.shape[0]\n",
|
||||
"\n",
|
||||
" out = torch.empty((TotalM, N), dtype=out_dtype, device=A.device)\n",
|
||||
"\n",
|
||||
" start = 0\n",
|
||||
" for i in range(G):\n",
|
||||
" end = offsets[i]\n",
|
||||
" out[start:end, :] = A[start:end, :] @ B[i, :, :]\n",
|
||||
" start = end\n",
|
||||
"\n",
|
||||
" return out"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d0bf2f91",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Contiguous offset grouped GEMM in PyTorch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4308a6a2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The same operation is performed by `torch`'s `torch._grouped_mm` (torch < 2.10)\n",
|
||||
"and `torch.nn.functional.grouped_mm` (torch >= 2.10)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "043906af",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"TotalM = 8192\n",
|
||||
"G = 12\n",
|
||||
"K = 1024\n",
|
||||
"N = 2048\n",
|
||||
"\n",
|
||||
"offsets = torch.arange(TotalM // G, TotalM, TotalM // G, device=\"cuda\", dtype=torch.int32)\n",
|
||||
"offsets[-1] = TotalM\n",
|
||||
"\n",
|
||||
"A = torch.randn(TotalM, K, device=\"cuda\", dtype=torch.bfloat16)\n",
|
||||
"B = torch.randn(G, N, K, device=\"cuda\", dtype=torch.bfloat16).permute(0, 2, 1)\n",
|
||||
"\n",
|
||||
"out_torch = torch._grouped_mm(A, B, offsets, out_dtype=torch.bfloat16)\n",
|
||||
"reference = reference_contiguous_offset_grouped_gemm(A, B, offsets, out_dtype=torch.bfloat16)\n",
|
||||
"\n",
|
||||
"torch.testing.assert_close(out_torch, reference)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0d0e9479",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Contiguous offset grouped GEMM in CUTLASS API\n",
|
||||
"\n",
|
||||
"CUTLASS API exposes this contiguous offset grouped GEMM via `GroupedGemmArguments`,\n",
|
||||
"which are constructed similarly to `GemmArguments`, but take in an `offsets`\n",
|
||||
"tensor as well:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "ff8d3ef1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"out = torch.empty((TotalM, N), device=\"cuda\", dtype=torch.bfloat16)\n",
|
||||
"\n",
|
||||
"args = cutlass_api.arguments.GroupedGemmArguments(\n",
|
||||
" A,\n",
|
||||
" B,\n",
|
||||
" out,\n",
|
||||
" accumulator_type=torch.float32,\n",
|
||||
" offsets=offsets,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0dc6d1cb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"One can then use the same APIs for finding, compiling, and executing a\n",
|
||||
"kernel supporting this operation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "80213e1e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"kernels = cutlass_api.get_kernels(args, cc=100)\n",
|
||||
"\n",
|
||||
"assert kernels, \"No kernels found\"\n",
|
||||
"\n",
|
||||
"# Select the first kernel found for simplicity\n",
|
||||
"kernel = kernels[0]\n",
|
||||
"\n",
|
||||
"compiled_kernel = kernel.compile(args)\n",
|
||||
"\n",
|
||||
"# Execute the kernel\n",
|
||||
"kernel.run(args, compiled_artifact=compiled_kernel)\n",
|
||||
"\n",
|
||||
"torch.testing.assert_close(out, reference)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -12,7 +12,7 @@ requires-python = ">=3.12"
|
||||
|
||||
dependencies = [
|
||||
"nvidia-cutlass-dsl",
|
||||
"apache-tvm-ffi", # Required for optimal performance. To opt-out: uninstall or set cutlass_api.config.GlobalOptions().use_tvm_ffi=False
|
||||
"apache-tvm-ffi", # Required for best performance. To opt-out: uninstall or set cutlass_api.config.GlobalOptions().use_tvm_ffi=False
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -32,6 +32,33 @@ import pytest
|
||||
import cutlass_api
|
||||
from cutlass_api.config import GlobalOptions
|
||||
|
||||
|
||||
# Global variable to store the test level, accessible at collection time
|
||||
_test_level = None
|
||||
|
||||
|
||||
def get_test_level():
|
||||
"""Get the test level set via --level CLI option."""
|
||||
return _test_level
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--level",
|
||||
action="store",
|
||||
default="L0",
|
||||
type=str,
|
||||
choices=["L0", "L1", "L2"],
|
||||
help="Test level to run",
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
global _test_level
|
||||
|
||||
_test_level = config.getoption("--level", "L0")
|
||||
|
||||
|
||||
#
|
||||
# Before each test, save the GlobalOptions dict
|
||||
# After each test, restore the GlobalOptions dict
|
||||
|
||||
@@ -0,0 +1,332 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import functools
|
||||
import os
|
||||
import pytest
|
||||
import random
|
||||
import torch
|
||||
|
||||
import cutlass
|
||||
|
||||
import cutlass_api
|
||||
from cutlass_api.utils import is_device_cc_supported
|
||||
|
||||
|
||||
torch.manual_seed(2025)
|
||||
random.seed(2025)
|
||||
|
||||
|
||||
def create_offsets(
|
||||
problems_in_group: int,
|
||||
expect_m: int,
|
||||
force_empty_problems: bool,
|
||||
) -> tuple[int, list[int], torch.Tensor]:
|
||||
"""
|
||||
Utility to create offsets for the contiguous offset tensor.
|
||||
|
||||
:param problems_in_group: Number of problems to create offsets for.
|
||||
:type problems_in_group: int
|
||||
:param expect_m: Expected number of rows in each group.
|
||||
:type expect_m: int
|
||||
:param force_empty_problems: Whether to force some problems in the group to be empty.
|
||||
:type force_empty_problems: bool
|
||||
|
||||
:return: Total number of rows, list of group sizes, and offset mapping tensor.
|
||||
:rtype: tuple[int, list[int], torch.Tensor]
|
||||
"""
|
||||
valid_m = 0
|
||||
problem_m_list = []
|
||||
for i in range(problems_in_group):
|
||||
problem_m = random.randint(int(expect_m * 0.7), int(expect_m * 1.3))
|
||||
valid_m += problem_m
|
||||
# handle the case that valid_m == 0
|
||||
if (i == problems_in_group - 1) and (valid_m == 0):
|
||||
problem_m = 128
|
||||
valid_m += problem_m
|
||||
problem_m_list.append(problem_m)
|
||||
|
||||
if force_empty_problems:
|
||||
problems_to_zero = random.sample(
|
||||
range(problems_in_group), max(problems_in_group - 2, 1)
|
||||
)
|
||||
for problem_idx in problems_to_zero:
|
||||
valid_m -= problem_m_list[problem_idx]
|
||||
problem_m_list[problem_idx] = 0
|
||||
|
||||
end_mapping = torch.cumsum(torch.tensor(problem_m_list, dtype=torch.int32), dim=0)
|
||||
offsets = torch.tensor(end_mapping, device="cuda", dtype=torch.int32)
|
||||
return valid_m, problem_m_list, offsets
|
||||
|
||||
|
||||
def args_and_ref_from_metadata(
|
||||
metadata: cutlass_api.metadata.KernelMetadata,
|
||||
mnkl: tuple[int, int, int, int],
|
||||
force_empty_problems: bool,
|
||||
):
|
||||
m, n, k, l = mnkl
|
||||
|
||||
valid_m, problem_m_list, offsets = create_offsets(l, m, force_empty_problems)
|
||||
|
||||
# Create A (1, valid_m, k)
|
||||
if metadata.operands.A.stride[-2:].index(1) == 1:
|
||||
A = torch.randint(-1, 2, (valid_m, k), device="cuda").to(
|
||||
cutlass.torch.dtype(metadata.operands.A.dtype)
|
||||
)
|
||||
else:
|
||||
A = (
|
||||
torch.randint(-1, 2, (k, valid_m), device="cuda")
|
||||
.to(cutlass.torch.dtype(metadata.operands.A.dtype))
|
||||
.T
|
||||
)
|
||||
|
||||
# Create B (l, k, n)
|
||||
if metadata.operands.B.stride[-2:].index(1) == 1:
|
||||
B = torch.randint(-1, 2, (l, k, n), device="cuda").to(
|
||||
cutlass.torch.dtype(metadata.operands.B.dtype)
|
||||
)
|
||||
else:
|
||||
B = (
|
||||
torch.randint(-1, 2, (l, n, k), device="cuda")
|
||||
.to(cutlass.torch.dtype(metadata.operands.B.dtype))
|
||||
.permute(0, 2, 1)
|
||||
)
|
||||
|
||||
# Create out (1, valid_m, n)
|
||||
if metadata.operands.out.stride[-2:].index(1) == 1:
|
||||
out = torch.zeros(
|
||||
(valid_m, n),
|
||||
device="cuda",
|
||||
dtype=cutlass.torch.dtype(metadata.operands.out.dtype),
|
||||
)
|
||||
else:
|
||||
out = torch.zeros(
|
||||
(n, valid_m),
|
||||
device="cuda",
|
||||
dtype=cutlass.torch.dtype(metadata.operands.out.dtype),
|
||||
).T
|
||||
|
||||
args = cutlass_api.arguments.GroupedGemmArguments(
|
||||
A=A,
|
||||
B=B,
|
||||
out=out,
|
||||
accumulator_type=cutlass.torch.dtype(metadata.operands.accumulator_type),
|
||||
offsets=offsets,
|
||||
)
|
||||
|
||||
#
|
||||
# Compute reference
|
||||
#
|
||||
|
||||
out_type = cutlass.torch.dtype(metadata.operands.out.dtype)
|
||||
|
||||
# Compute reference in F32 because torch does not support GEMMs
|
||||
# for all FP8 types
|
||||
if hasattr(torch.nn.functional, "grouped_mm"):
|
||||
reference = torch.nn.functional.grouped_mm(
|
||||
A.float(), B.float(), offs=offsets
|
||||
).to(out_type)
|
||||
else:
|
||||
reference = torch._grouped_mm(A.float(), B.float(), offsets).to(out_type)
|
||||
|
||||
return args, reference, out
|
||||
|
||||
|
||||
def kernels_for_class(kernel_class):
|
||||
def metadata_filter(metadata):
|
||||
return metadata.kernel_class == kernel_class
|
||||
|
||||
kernels = cutlass_api.get_kernels(
|
||||
args=None, metadata_filter=metadata_filter, cc=100
|
||||
)
|
||||
assert len(kernels) > 0
|
||||
|
||||
# Toggle the number of kernels to return based on the test level
|
||||
try:
|
||||
from conftest import get_test_level
|
||||
|
||||
test_level = get_test_level()
|
||||
except ImportError:
|
||||
test_level = "L0"
|
||||
|
||||
if test_level == "L0":
|
||||
return random.sample(kernels, 10)
|
||||
else:
|
||||
return kernels
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mnkl, force_empty_problems",
|
||||
[
|
||||
((13, 16, 32, 3), False),
|
||||
((253, 32, 32, 7), False),
|
||||
((256, 256, 512, 1), False),
|
||||
((256, 256, 512, 10), False),
|
||||
((8192, 8192, 8192, 4), False),
|
||||
((8192, 8192, 8192, 4), True),
|
||||
((1024, 4096, 4096, 15), False),
|
||||
((256, 16384, 128, 20), False),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"kernel",
|
||||
kernels_for_class(
|
||||
cutlass_api.providers.cutedsl.gemm.sm100_contiguous_offset_2d3d_dense_gemm.ContiguousOffset2D3DGemmDenseKernel
|
||||
),
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not is_device_cc_supported({100})
|
||||
or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_100f"]),
|
||||
reason="Requires compute capability 100 and to be compiled with sm_100a or sm_100f",
|
||||
)
|
||||
def test_contiguous_offset_gemm_dense(mnkl, force_empty_problems, kernel):
|
||||
m, n, k, l = mnkl
|
||||
|
||||
args, reference, out = args_and_ref_from_metadata(
|
||||
kernel.metadata, mnkl, force_empty_problems
|
||||
)
|
||||
|
||||
if not (status := kernel.supports(args)):
|
||||
pytest.fail(
|
||||
f"Kernel {kernel.metadata.kernel_name} does not support the given arguments: {status}"
|
||||
)
|
||||
|
||||
kernel.run(args)
|
||||
torch.testing.assert_close(
|
||||
out, reference, msg=f"Kernel {kernel.metadata.kernel_name} failed"
|
||||
)
|
||||
|
||||
|
||||
def kernel_metadata_filter(metadata):
|
||||
return (
|
||||
metadata.kernel_class
|
||||
== cutlass_api.providers.cutedsl.gemm.sm100_contiguous_offset_2d3d_dense_gemm.ContiguousOffset2D3DGemmDenseKernel
|
||||
)
|
||||
|
||||
|
||||
def test_incorrect_ab_contiguous():
|
||||
"""
|
||||
Contiguous offset GEMMs currently require A to be in contiguous offset form.
|
||||
Test that no kernels are found when only B is in contiguous offset form.
|
||||
"""
|
||||
|
||||
problem_count, m, n, k = 12, 8192, 128, 512
|
||||
|
||||
A = torch.empty((problem_count, m, k), device="cuda", dtype=torch.float16)
|
||||
B = torch.empty((1, n, k), device="cuda", dtype=torch.float16).permute(0, 2, 1)
|
||||
out = torch.empty((1, m, n), device="cuda", dtype=torch.float32)
|
||||
|
||||
offsets = torch.empty((problem_count,), device="cuda", dtype=torch.int32)
|
||||
|
||||
args = cutlass_api.arguments.GroupedGemmArguments(
|
||||
A=A,
|
||||
B=B,
|
||||
out=out,
|
||||
accumulator_type=torch.float32,
|
||||
offsets=offsets,
|
||||
)
|
||||
|
||||
kernels = cutlass_api.get_kernels(
|
||||
args, metadata_filter=kernel_metadata_filter, cc=100
|
||||
)
|
||||
assert len(kernels) == 0
|
||||
|
||||
|
||||
def test_incorrect_offset_length():
    """
    Offset tensors are required to have `problem_count` elements.

    Test that no kernels are found when this is violated.
    """
    problem_count, m, n, k = 12, 8192, 128, 512

    A = torch.empty((1, m, k), device="cuda", dtype=torch.float16)
    B = torch.empty(
        (problem_count, n, k), device="cuda", dtype=torch.float16
    ).permute(0, 2, 1)
    out = torch.empty((1, m, n), device="cuda", dtype=torch.float32)

    # One element too many: a valid offset map has exactly `problem_count` entries.
    bad_offsets = torch.empty((problem_count + 1,), device="cuda", dtype=torch.int32)

    args = cutlass_api.arguments.GroupedGemmArguments(
        A=A,
        B=B,
        out=out,
        accumulator_type=torch.float32,
        offsets=bad_offsets,
    )

    found = cutlass_api.get_kernels(
        args, metadata_filter=kernel_metadata_filter, cc=100
    )
    assert len(found) == 0
|
||||
|
||||
|
||||
def test_incorrect_offsets_shape():
    """
    Offset maps are expected to be rank 1. Test that no kernels are found when
    this is not the case.
    """
    problem_count, m, n, k = 12, 8192, 128, 512

    A = torch.empty((1, m, k), device="cuda", dtype=torch.float16)
    B = torch.empty(
        (problem_count, n, k), device="cuda", dtype=torch.float16
    ).permute(0, 2, 1)
    out = torch.empty((1, m, n), device="cuda", dtype=torch.float32)

    offsets = torch.empty((problem_count,), device="cuda", dtype=torch.int32)

    def make_args(offset_tensor):
        # All arguments are identical except for the offset map under test.
        return cutlass_api.arguments.GroupedGemmArguments(
            A=A,
            B=B,
            out=out,
            accumulator_type=torch.float32,
            offsets=offset_tensor,
        )

    # Rank-1 offsets are the correct representation, so kernels must be found.
    valid_kernels = cutlass_api.get_kernels(
        make_args(offsets), metadata_filter=kernel_metadata_filter, cc=100
    )
    assert len(valid_kernels) > 0

    # The same data viewed as rank 2 must be rejected: no kernels found.
    invalid_kernels = cutlass_api.get_kernels(
        make_args(offsets.view(3, 4)), metadata_filter=kernel_metadata_filter, cc=100
    )
    assert len(invalid_kernels) == 0
|
||||
725
python/cutlass_api/test/integration/test_blockscaled_gemm.py
Normal file
725
python/cutlass_api/test/integration/test_blockscaled_gemm.py
Normal file
@@ -0,0 +1,725 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from collections.abc import Callable
|
||||
from importlib.util import find_spec
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import cutlass
|
||||
|
||||
import cutlass_api
|
||||
from cutlass_api.arguments import EpilogueArguments, ScaledTensor
|
||||
from cutlass_api.config import GlobalOptions
|
||||
from cutlass_api.library import ScaleMode, ScaleSwizzleMode
|
||||
from cutlass_api.metadata import KernelMetadata, Sm100DesignMetadata
|
||||
from cutlass_api.utils import ceil_div, is_device_cc_supported, round_up
|
||||
|
||||
|
||||
torch.manual_seed(2025)
|
||||
|
||||
|
||||
def prep_k(K: int, scale_size: int):
    """
    Prepares the K mode of a scale-factor tensor for 32-4-4 swizzling.

    :param K: The GEMM K extent to prepare.
    :type K: int
    :param scale_size: Number of elements covered by one scale factor.
    :type scale_size: int

    :return: ceil(K / scale_size) rounded up to the multiple of 4 required by
        32-4-4 swizzling.
    :rtype: int
    """
    scale_k = ceil_div(K, scale_size)
    return round_up(scale_k, 4)
|
||||
|
||||
|
||||
def reference_scaled_mm(
    A: torch.Tensor,
    B: torch.Tensor,
    scale_A: torch.Tensor,
    scale_B: torch.Tensor,
    out_dtype: torch.dtype,
    transform_sf: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
):
    """
    Computes a reference scaled mm operation. Currently, torch._scaled_mm does
    not support batch mode, so for rank-3 inputs this function loops over the
    batch dimension and calls torch._scaled_mm once per problem.

    :param A: The A tensor
    :type A: torch.Tensor
    :param B: The B tensor
    :type B: torch.Tensor
    :param scale_A: The scale factor tensor for operand A
    :type scale_A: torch.Tensor
    :param scale_B: The scale factor tensor for operand B
    :type scale_B: torch.Tensor
    :param out_dtype: The output dtype
    :type out_dtype: torch.dtype
    :param transform_sf: A function applied to each per-problem scale factor
        slice before it is handed to torch._scaled_mm (identity by default)
    :type transform_sf: Callable
    :return: The reference scaled mm result
    """
    if A.dim() == 2:
        # Unbatched case maps directly onto torch._scaled_mm.
        return torch._scaled_mm(
            A, B, scale_a=scale_A, scale_b=scale_B, out_dtype=out_dtype
        )

    # Batched case: torch._scaled_mm has no batch support, so run each
    # problem in the batch individually.
    batches, rows, cols = A.shape[0], A.shape[1], B.shape[2]
    result = torch.empty((batches, rows, cols), device=A.device, dtype=out_dtype)
    for batch_idx in range(batches):
        result[batch_idx] = torch._scaled_mm(
            A[batch_idx],
            B[batch_idx],
            scale_a=transform_sf(scale_A[batch_idx]),
            scale_b=transform_sf(scale_B[batch_idx]),
            out_dtype=out_dtype,
        )
    return result
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "M, N, K, L",
    [
        (256, 512, 1024, 1),
        (1024, 640, 512, 2),
        (256, 512, 1024 + 496, 1),  # Test where K is not divisible by scale_size (32)
    ],
)
@pytest.mark.parametrize(
    "ab_dtype, c_dtype, accumulator_type, scale_dtype",
    [
        (torch.float8_e4m3fn, torch.float32, torch.float32, torch.float8_e8m0fnu),
        (torch.float8_e5m2, torch.float32, torch.float32, torch.float8_e8m0fnu),
    ],
)
@pytest.mark.skipif(
    not is_device_cc_supported({100, 103})
    or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
    # Fixed: the condition accepts CC 100 *or* 103; the reason now matches it.
    reason="Requires compute capability 100 or 103 and to be compiled with sm_100a or sm_103a",
)
def test_mxfp8_gemm_sm100(
    M: int,
    N: int,
    K: int,
    L: int,
    ab_dtype: torch.dtype,
    c_dtype: torch.dtype,
    accumulator_type: torch.dtype,
    scale_dtype: torch.dtype,
    fixture_enable_tvm_ffi,
):
    """
    End-to-end MXFP8 GEMM on SM100 with 3D (batched) tensors: discovers a
    kernel, compiles and runs it, and checks against a torch._scaled_mm
    reference where that reference supports the input dtype.
    """
    # Create torch fp8 tensors for A and B
    A = torch.randint(-1, 2, (L, M, K), device="cuda").to(ab_dtype)
    D = torch.empty((L, M, N), device="cuda", dtype=c_dtype)

    # Transpose B because torch._scaled_mm expects B to be column-major
    B = torch.randint(-1, 2, (L, N, K), device="cuda").to(ab_dtype).transpose(1, 2)

    scale_size = 32

    # Scale factor K extents are padded via prep_k for 32-4-4 swizzling.
    SFA = torch.rand((L, M, prep_k(K, scale_size)), device="cuda").to(scale_dtype)
    SFB = torch.rand((L, prep_k(K, scale_size), N), device="cuda").to(scale_dtype)

    args = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            A,
            SFA,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=ScaledTensor(
            B,
            SFB,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=accumulator_type,
    )
    kernels = cutlass_api.get_kernels(args, cc=100)

    assert len(kernels) > 0
    kernel = kernels[0]
    assert kernel.supports(args)
    compiled_artifact = kernel.compile(args)
    kernel.run(args, compiled_artifact=compiled_artifact, assume_supported_args=True)

    # torch._scaled_mm does not support f8e5m2 * f8e5m2 currently.
    # Simply skip reference check in that case (but test that a CUTLASS API kernel
    # is found and runs)
    if ab_dtype != torch.float8_e5m2:
        reference = reference_scaled_mm(A, B, SFA, SFB, c_dtype)
        torch.testing.assert_close(D, reference)
|
||||
|
||||
@pytest.mark.parametrize(
    "M, N, K",
    [
        (256, 512, 1024),
        (1024, 640, 512),
    ],
)
@pytest.mark.parametrize(
    "ab_dtype, c_dtype, accumulator_type, scale_dtype",
    [
        (torch.float8_e4m3fn, torch.float32, torch.float32, torch.float8_e8m0fnu),
    ],
)
@pytest.mark.skipif(
    not is_device_cc_supported({100, 103})
    or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
    # Fixed: the condition accepts CC 100 *or* 103; the reason now matches it.
    reason="Requires compute capability 100 or 103 and to be compiled with sm_100a or sm_103a",
)
def test_mxfp8_gemm_sm100_2d(
    M: int,
    N: int,
    K: int,
    ab_dtype: torch.dtype,
    c_dtype: torch.dtype,
    accumulator_type: torch.dtype,
    scale_dtype: torch.dtype,
    fixture_enable_tvm_ffi,
):
    """
    Tests valid MXFP8 GEMM cases in which A, B, and out are 2D tensors.
    """
    # Create torch fp8 tensors for A and B
    A = torch.randint(-1, 2, (M, K), device="cuda").to(ab_dtype)
    D = torch.empty((M, N), device="cuda", dtype=c_dtype)

    # Transpose B because torch._scaled_mm expects B to be column-major
    B = torch.randint(-1, 2, (N, K), device="cuda").to(ab_dtype).transpose(0, 1)

    scale_size = 32
    # Scale factor K extents are padded via prep_k for 32-4-4 swizzling.
    SFA = torch.rand((M, prep_k(K, scale_size)), device="cuda").to(scale_dtype)
    SFB = torch.rand((prep_k(K, scale_size), N), device="cuda").to(scale_dtype)

    args = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            A,
            SFA,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=ScaledTensor(
            B,
            SFB,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=accumulator_type,
    )
    kernels = cutlass_api.get_kernels(args, cc=100)

    assert len(kernels) > 0
    kernel = kernels[0]
    assert kernel.supports(args)
    compiled_artifact = kernel.compile(args)
    kernel.run(args, compiled_artifact=compiled_artifact, assume_supported_args=True)

    reference = reference_scaled_mm(A, B, SFA, SFB, c_dtype)
    torch.testing.assert_close(D, reference)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not is_device_cc_supported({100, 103})
    or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
    # Fixed: the condition accepts CC 100 *or* 103; the reason now matches it.
    reason="Requires compute capability 100 or 103 and to be compiled with sm_100a or sm_103a",
)
def test_mxfp8_1d_scale_factors(fixture_enable_tvm_ffi):
    """
    Tests for valid MXFP8 GEMM cases in which A, B, and out are 3D tensors, and the scale factors are 1D tensors.
    """
    M, N, K, L = 256, 512, 1024, 1
    scale_size = 32

    # Create torch fp8 tensors for A and B
    A = torch.randint(-1, 2, (L, M, K), device="cuda").to(torch.float8_e4m3fn)
    B = torch.randint(-1, 2, (L, N, K), device="cuda").to(torch.float8_e4m3fn).transpose(1, 2)
    D = torch.empty((L, M, N), device="cuda", dtype=torch.float32)

    SFA = torch.rand((L, M, prep_k(K, scale_size)), device="cuda").to(torch.float8_e8m0fnu)
    SFB = torch.rand((L, prep_k(K, scale_size), N), device="cuda").to(torch.float8_e8m0fnu)

    args = cutlass_api.arguments.GemmArguments(
        # Pass in SFA and SFB in flattened form (.view(-1))
        A=ScaledTensor(
            A,
            SFA.view(-1),
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=ScaledTensor(
            B,
            SFB.view(-1),
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=torch.float32,
    )

    # Only test kernels for which we know there will be an error if the scale factors are invalid
    def metadata_filter(metadata: KernelMetadata):
        return (
            metadata.kernel_class
            == cutlass_api.providers.cutedsl.gemm.sm100_dense_blockscaled_static_persistent.PersistentDenseBlockScaledGemmKernel
        )

    kernels = cutlass_api.get_kernels(args, cc=100, metadata_filter=metadata_filter)
    assert len(kernels) > 0
    kernel = kernels[0]
    assert kernel.supports(args)
    compiled_artifact = kernel.compile(args)
    kernel.run(args, compiled_artifact=compiled_artifact, assume_supported_args=True)

    def transform_sf(x: torch.Tensor) -> torch.Tensor:
        """Flattens scale factor tensors after indexing into the batch dimension"""
        return x.view(-1)

    reference = reference_scaled_mm(A, B, SFA, SFB, torch.float32, transform_sf)
    torch.testing.assert_close(D, reference)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "discovery_scale_mode, runtime_scale_mode",
    [
        ((1, 1, 32), (1, 1, 32)),
        ((1, 32), (1, 1, 32)),
        (ScaleMode.Blockwise1x32, ScaleMode.Blockwise1x32),
        (ScaleMode.Blockwise1x32, (1, 32)),
    ],
)
@pytest.mark.skipif(
    not is_device_cc_supported({100, 103})
    or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
    # Fixed: the condition accepts CC 100 *or* 103; the reason now matches it.
    reason="Requires compute capability 100 or 103 and to be compiled with sm_100a or sm_103a",
)
def test_mxfp8_tuple_scale_mode(
    discovery_scale_mode: tuple[int, ...],
    runtime_scale_mode: tuple[int, ...],
    fixture_enable_tvm_ffi,
):
    """
    Tests for valid MXFP8 GEMM cases in which A, B, and out are 3D tensors, and
    the scale mode is specified as a tuple.

    A kernel discovered/compiled with one scale-mode spelling must accept
    runtime arguments using an equivalent spelling.
    """
    M, N, K, L = 256, 512, 1024, 1
    scale_size = 32

    # Create torch fp8 tensors for A and B
    A = torch.randint(-1, 2, (L, M, K), device="cuda").to(torch.float8_e4m3fn)
    B = (
        torch.randint(-1, 2, (L, N, K), device="cuda")
        .to(torch.float8_e4m3fn)
        .transpose(1, 2)
    )
    D = torch.empty((L, M, N), device="cuda", dtype=torch.float32)

    SFA = torch.rand(
        (
            L,
            M,
            prep_k(K, scale_size),
        ),
        device="cuda",
    ).to(torch.float8_e8m0fnu)
    SFB = torch.rand((L, prep_k(K, scale_size), N), device="cuda").to(
        torch.float8_e8m0fnu
    )

    discovery_args = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            A,
            SFA,
            discovery_scale_mode,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=ScaledTensor(
            B,
            SFB,
            discovery_scale_mode,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=torch.float32,
    )

    # Only test kernels for which we know there will be an error if the scale factors are invalid
    def metadata_filter(metadata: KernelMetadata):
        return (
            metadata.kernel_class
            == cutlass_api.providers.cutedsl.gemm.sm100_dense_blockscaled_static_persistent.PersistentDenseBlockScaledGemmKernel
        )

    kernels = cutlass_api.get_kernels(
        discovery_args, cc=100, metadata_filter=metadata_filter
    )
    assert len(kernels) > 0
    kernel = kernels[0]
    compiled_artifact = kernel.compile(discovery_args)

    # Same tensors, but the scale mode is given in the runtime spelling.
    runtime_args = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            A,
            SFA,
            runtime_scale_mode,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=ScaledTensor(
            B,
            SFB,
            runtime_scale_mode,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=torch.float32,
    )
    assert kernel.supports(runtime_args)
    kernel.run(
        runtime_args, compiled_artifact=compiled_artifact, assume_supported_args=True
    )

    reference = reference_scaled_mm(A, B, SFA, SFB, torch.float32)
    torch.testing.assert_close(D, reference)
|
||||
|
||||
@pytest.mark.skipif(
    not is_device_cc_supported({100, 103})
    or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
    # Fixed: the condition accepts CC 100 *or* 103; the reason now matches it.
    reason="Requires compute capability 100 or 103 and to be compiled with sm_100a or sm_103a",
)
def test_mxfp8_incompatible_mode_or_swizzle(fixture_enable_tvm_ffi):
    """
    Tests for correct flagging of incompatible scale modes or swizzles:
    - Using no swizzle mode for MXFP8
    - Using an incompatible scale mode for MXFP8
    """
    M, N, K, L = 256, 512, 1024, 1
    scale_size = 32

    # Create torch fp8 tensors for A and B
    A = torch.randint(-1, 2, (L, M, K), device="cuda").to(torch.float8_e4m3fn)
    B = torch.randint(-1, 2, (L, N, K), device="cuda").to(torch.float8_e4m3fn).transpose(1, 2)
    D = torch.empty((L, M, N), device="cuda", dtype=torch.float32)

    SFA = torch.rand((L, M, prep_k(K, scale_size)), device="cuda").to(torch.float8_e8m0fnu)
    SFB = torch.rand((L, prep_k(K, scale_size), N), device="cuda").to(torch.float8_e8m0fnu)

    args_bad = cutlass_api.arguments.GemmArguments(
        # Use incompatible swizzle mode for SFA
        A=ScaledTensor(
            A,
            SFA,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.SwizzleNone,
        ),
        B=ScaledTensor(
            B,
            SFB,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=torch.float32,
    )

    # Only test kernels for which the swizzle mode is incompatible
    def metadata_filter(metadata: KernelMetadata):
        return (
            metadata.kernel_class
            == cutlass_api.providers.cutedsl.gemm.sm100_dense_blockscaled_static_persistent.PersistentDenseBlockScaledGemmKernel
        )

    kernels = cutlass_api.get_kernels(args_bad, cc=100, metadata_filter=metadata_filter)
    assert len(kernels) == 0

    # Find kernel using compatible swizzle mode, but pass in arguments at runtime with
    # incompatible swizzle mode
    args_good = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            A,
            SFA,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=ScaledTensor(
            B,
            SFB,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=torch.float32,
    )
    kernels = cutlass_api.get_kernels(args_good, cc=100, metadata_filter=metadata_filter)
    assert len(kernels) > 0
    kernel = kernels[0]

    # Fixed idiom: pytest.raises replaces the try/assert False/except-pass
    # pattern, which silently discarded the caught exception.
    with pytest.raises(ValueError):
        # run() must fail its supports() check on the incompatible swizzle mode.
        kernel.run(args_bad)

    # Pass in arguments with an incompatible scale mode
    args_bad_scale = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            A,
            SFA,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        # Bad scale mode: SFB had scale mode Blockwise1x16
        B=ScaledTensor(
            B,
            SFB,
            ScaleMode.Blockwise1x16,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=torch.float32,
    )
    with pytest.raises(ValueError):
        # run() must fail its supports() check on the mismatched scale mode.
        kernel.run(args_bad_scale)
|
||||
|
||||
@pytest.mark.skipif(
    not is_device_cc_supported({100, 103})
    or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
    # Fixed: the condition accepts CC 100 *or* 103; the reason now matches it.
    reason="Requires compute capability 100 or 103 and to be compiled with sm_100a or sm_103a",
)
def test_mxfp8_missing_scale_factors(fixture_enable_tvm_ffi):
    """
    Tests for correct flagging of a missing scale factor (both must be supplied)
    """
    M, N, K, L = 256, 512, 1024, 1
    scale_size = 32

    # Create torch fp8 tensors for A and B
    A = torch.randint(-1, 2, (L, M, K), device="cuda").to(torch.float8_e4m3fn)
    B = torch.randint(-1, 2, (L, N, K), device="cuda").to(torch.float8_e4m3fn).transpose(1, 2)
    D = torch.empty((L, M, N), device="cuda", dtype=torch.float32)
    SFA = torch.rand((L, M, prep_k(K, scale_size)), device="cuda").to(torch.float8_e8m0fnu)

    # Construct arguments with only one scale factor: A is scaled, B is a
    # plain tensor with no ScaledTensor wrapper.
    args = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            A,
            SFA,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=B,
        out=D,
        accumulator_type=torch.float32,
    )

    # Only test kernels for which we know there will be an error if the scale factors are missing
    def metadata_filter(metadata: KernelMetadata):
        return (
            metadata.kernel_class
            == cutlass_api.providers.cutedsl.gemm.sm100_dense_blockscaled_static_persistent.PersistentDenseBlockScaledGemmKernel
        )

    kernels = cutlass_api.get_kernels(args, cc=100, metadata_filter=metadata_filter)
    assert len(kernels) == 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not is_device_cc_supported({100, 103})
    or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
    # Fixed: the condition accepts CC 100 *or* 103; the reason now matches it.
    reason="Requires compute capability 100 or 103 and to be compiled with sm_100a or sm_103a",
)
def test_mxfp8_invalid_scale_factors(fixture_enable_tvm_ffi):
    """
    Tests for correct flagging of invalid scale factors:
    - Too large for the A tensor
    - Too large for the B tensor
    """
    M, N, K, L = 256, 512, 1024, 1
    scale_size = 32

    # Create torch fp8 tensors for A and B
    A = torch.randint(-1, 2, (L, M, K), device="cuda").to(torch.float8_e4m3fn)
    B = torch.randint(-1, 2, (L, N, K), device="cuda").to(torch.float8_e4m3fn).transpose(1, 2)
    D = torch.empty((L, M, N), device="cuda", dtype=torch.float32)

    # Only test kernels for which we know there will be an error if the scale factors are invalid
    def metadata_filter(metadata: KernelMetadata):
        return (
            metadata.kernel_class
            == cutlass_api.providers.cutedsl.gemm.sm100_dense_blockscaled_static_persistent.PersistentDenseBlockScaledGemmKernel
        )

    def make_args(sfa, sfb):
        # Helper: identical GEMM arguments modulo the scale-factor tensors.
        return cutlass_api.arguments.GemmArguments(
            A=ScaledTensor(
                A,
                sfa,
                ScaleMode.Blockwise1x32,
                ScaleSwizzleMode.Swizzle32x4x4,
            ),
            B=ScaledTensor(
                B,
                sfb,
                ScaleMode.Blockwise1x32,
                ScaleSwizzleMode.Swizzle32x4x4,
            ),
            out=D,
            accumulator_type=torch.float32,
        )

    # Add 32 elements to the K mode of SFA. This makes it too large for the A tensor
    SFA = torch.rand((L, M, prep_k(K, scale_size) + 32), device="cuda").to(torch.float8_e8m0fnu)
    SFB = torch.rand((L, prep_k(K, scale_size), N), device="cuda").to(torch.float8_e8m0fnu)
    kernels = cutlass_api.get_kernels(
        make_args(SFA, SFB), cc=100, metadata_filter=metadata_filter
    )
    assert len(kernels) == 0

    # Add 32 elements to the K mode of SFB. This makes it too large for the B tensor
    SFA = torch.rand((L, M, prep_k(K, scale_size)), device="cuda").to(torch.float8_e8m0fnu)
    SFB = torch.rand((L, prep_k(K, scale_size) + 32, N), device="cuda").to(torch.float8_e8m0fnu)
    kernels = cutlass_api.get_kernels(
        make_args(SFA, SFB), cc=100, metadata_filter=metadata_filter
    )
    assert len(kernels) == 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not is_device_cc_supported({100, 103})
    or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
    # Fixed: the condition accepts CC 100 *or* 103; the reason now matches it.
    reason="Requires compute capability 100 or 103 and to be compiled with sm_100a or sm_103a",
)
def test_mxfp8_gemm_sm100_invalid_epilogue(fixture_enable_tvm_ffi):
    """
    Tests for correct flagging of invalid epilogues: currently, no MXFP8 kernels support
    epilogue fusions.
    """
    M, N, K, L = 256, 512, 1024, 1

    # Create torch fp8 tensors for A and B
    A = torch.randint(-1, 2, (L, M, K), device="cuda").to(torch.float8_e4m3fn)
    B = torch.randint(-1, 2, (L, N, K), device="cuda").to(torch.float8_e4m3fn).transpose(1, 2)
    D = torch.empty((L, M, N), device="cuda", dtype=torch.float32)

    scale_size = 32
    SFA = torch.rand((L, M, prep_k(K, scale_size)), device="cuda").to(torch.float8_e8m0fnu)
    SFB = torch.rand((L, prep_k(K, scale_size), N), device="cuda").to(torch.float8_e8m0fnu)

    C = torch.randint(-1, 2, (L, M, N), device="cuda").to(torch.float32)

    def epilogue(accum, alpha, beta, C):
        # Classic alpha/beta linear-combination epilogue.
        D = alpha * accum + beta * C
        return D

    epi_args = EpilogueArguments(epilogue, alpha=1.0, beta=1.0, C=C, D=D)

    args = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            A,
            SFA,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=ScaledTensor(
            B,
            SFB,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=torch.float32,
        epilogue=epi_args,
    )
    kernels = cutlass_api.get_kernels(args, cc=100)

    # Epilogue fusion is unsupported for MXFP8, so no kernel may match.
    assert len(kernels) == 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not is_device_cc_supported({100, 103})
    or (os.getenv("CUTE_DSL_ARCH", "") not in ["", "sm_100a", "sm_103a"]),
    # Fixed: the condition accepts CC 100 *or* 103; the reason now matches it.
    reason="Requires compute capability 100 or 103 and to be compiled with sm_100a or sm_103a",
)
def test_mxfp8_gemm_sm100_design_metadata(fixture_enable_tvm_ffi):
    """
    Tests for correct filtering of kernels based on design metadata specifications
    (e.g., tile shape, cluster shape, etc.).
    """
    M, N, K, L = 256, 512, 1024, 1

    # Create torch fp8 tensors for A and B
    A = torch.randint(-1, 2, (L, M, K), device="cuda").to(torch.float8_e4m3fn)
    B = torch.randint(-1, 2, (L, N, K), device="cuda").to(torch.float8_e4m3fn).transpose(1, 2)
    D = torch.empty((L, M, N), device="cuda", dtype=torch.float32)

    scale_size = 32
    SFA = torch.rand((L, M, prep_k(K, scale_size)), device="cuda").to(torch.float8_e8m0fnu)
    SFB = torch.rand((L, prep_k(K, scale_size), N), device="cuda").to(torch.float8_e8m0fnu)

    args = cutlass_api.arguments.GemmArguments(
        A=ScaledTensor(
            A,
            SFA,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        B=ScaledTensor(
            B,
            SFB,
            ScaleMode.Blockwise1x32,
            ScaleSwizzleMode.Swizzle32x4x4,
        ),
        out=D,
        accumulator_type=torch.float32,
    )

    def design_filter(metadata: KernelMetadata):
        # Keep only SM100 designs with a 256x128 M/N tile shape.
        if not isinstance(metadata.design, Sm100DesignMetadata):
            return False
        return metadata.design.tile_shape[:2] == (256, 128)

    kernels = cutlass_api.get_kernels(args, cc=100, metadata_filter=design_filter)
    assert len(kernels) > 0

    # Every returned kernel must satisfy the filter it was selected by.
    for kernel in kernels:
        assert design_filter(kernel.metadata)

    kernel = kernels[0]
    assert kernel.supports(args)
    compiled_artifact = kernel.compile(args)
    kernel.run(args, compiled_artifact=compiled_artifact)

    reference = reference_scaled_mm(A, B, SFA, SFB, torch.float32)
    torch.testing.assert_close(D, reference)
|
||||
@@ -0,0 +1,69 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import functools
|
||||
import os
|
||||
import pytest
|
||||
import random
|
||||
import torch
|
||||
|
||||
import cutlass_api
|
||||
|
||||
|
||||
torch.manual_seed(2025)
|
||||
random.seed(2025)
|
||||
|
||||
|
||||
def test_incorrect_offset_length():
    """
    Offset tensors are required to have `problem_count` elements.

    Test that no kernels are found when this is violated.
    """
    problem_count, m, n, k = 12, 8192, 128, 512

    A = torch.empty((1, m, k), device="cuda", dtype=torch.float16)
    B = torch.empty(
        (problem_count, n, k), device="cuda", dtype=torch.float16
    ).permute(0, 2, 1)
    out = torch.empty((1, m, n), device="cuda", dtype=torch.float32)

    # Off by one: a valid offset map has exactly `problem_count` entries.
    offsets = torch.empty((problem_count + 1,), device="cuda", dtype=torch.int32)

    args = cutlass_api.arguments.GroupedGemmArguments(
        A=A,
        B=B,
        out=out,
        accumulator_type=torch.float32,
        offsets=offsets,
    )

    assert len(cutlass_api.get_kernels(args, cc=100)) == 0
|
||||
@@ -37,11 +37,12 @@ import cutlass_api
|
||||
@pytest.mark.parametrize(
|
||||
"notebook_name, supported_ccs",
|
||||
[
|
||||
("000_gemm.ipynb", [80, 89, 90,100, 103]),
|
||||
("000_gemm.ipynb", [80, 89, 90, 100, 103]),
|
||||
("001_gemm_with_fused_epilogue.ipynb", [100, 103]),
|
||||
("002_bring_your_own_kernel.ipynb", [80, 89, 90, 100, 103, 120, 121]),
|
||||
("003_host_latency_best_practices.ipynb", [80, 89, 90, 100, 103]),
|
||||
("004_fake_tensors.ipynb", [80, 89, 90, 100, 103]),
|
||||
("005_grouped_gemm_contiguous_offset.ipynb", [100]),
|
||||
],
|
||||
)
|
||||
def test_notebooks(notebook_name, supported_ccs):
|
||||
|
||||
@@ -37,7 +37,7 @@ from cutlass_api.arguments import ElementwiseArguments
|
||||
from cutlass_api.config import GlobalOptions
|
||||
from cutlass_api.metadata import (
|
||||
ElementwiseOperandsMetadata,
|
||||
TensorAttributes,
|
||||
DenseTensorAttributes,
|
||||
)
|
||||
|
||||
|
||||
@@ -49,7 +49,9 @@ class NoopKernelForTesting(cutlass_api.providers.cutedsl.kernel.CuteDslKernel):
|
||||
|
||||
def compile(self, args: ElementwiseArguments):
|
||||
stream = cute.runtime.make_fake_stream()
|
||||
return self.cute_compile(self.impl, args.A, args.B, args.out, stream)
|
||||
return self.cute_compile(
|
||||
self.impl, args.A.tensor, args.B.tensor, args.out.tensor, stream
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
@@ -58,10 +60,12 @@ class NoopKernelForTesting(cutlass_api.providers.cutedsl.kernel.CuteDslKernel):
|
||||
stream,
|
||||
workspace=None,
|
||||
):
|
||||
self.cute_run(compiled_artifact, args.A, args.B, args.out, stream)
|
||||
self.cute_run(
|
||||
compiled_artifact, args.A.tensor, args.B.tensor, args.out.tensor, stream
|
||||
)
|
||||
|
||||
def generate_kernels(_ignored_filter, _ignored_epilogue_args, _ignored_cc):
|
||||
attrs = TensorAttributes(
|
||||
attrs = DenseTensorAttributes(
|
||||
stride=(0, 1),
|
||||
dtype=cutlass.Float16,
|
||||
divisibility=8,
|
||||
|
||||
Reference in New Issue
Block a user