v4.2 release. (#2587)

* Fix default cluster callback values to 1 to avoid profiler failure when these values are not set in the command line.

* v4.2 release.
This commit is contained in:
Junkai-Wu
2025-08-23 06:11:24 +08:00
committed by GitHub
parent 11cad1f67b
commit a49a78ffef
351 changed files with 28182 additions and 2032 deletions

View File

@@ -4,7 +4,7 @@
This directory contains Python packages that are associated with CUTLASS:
* `cutlass`: the CUTLASS Python interface, which enables one to compile and run CUTLASS kernels from within Python
* `cutlass_cppgen`: the CUTLASS Python interface, which enables one to compile and run CUTLASS kernels from within Python. Note that this was previously named `cutlass`, but was renamed to disambiguate with the CuTe Python DSL.
* `cutlass_library`: utilities used for enumerating and emitting C++ code for CUTLASS kernels
## CUTLASS Python Interface

View File

@@ -119,8 +119,8 @@ def set_log_level(level: int):
set_log_level(logging.ERROR)
from cutlass.library_defaults import OptionRegistry
from cutlass.backend.utils.device import device_cc
from cutlass_cppgen.library_defaults import OptionRegistry
from cutlass_cppgen.backend.utils.device import device_cc
this._option_registry = None
def get_option_registry():
@@ -135,14 +135,14 @@ def get_option_registry():
this.__version__ = '4.1.0'
from cutlass.backend import create_memory_pool
from cutlass.emit.pytorch import pytorch
from cutlass.op.gemm import Gemm
from cutlass.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
from cutlass.op.gemm_grouped import GroupedGemm
from cutlass.op.op import OperationBase
from cutlass.backend.evt.ir.tensor import Tensor
from cutlass.utils.lazy_import import lazy_import
from cutlass_cppgen.backend import create_memory_pool
from cutlass_cppgen.emit.pytorch import pytorch
from cutlass_cppgen.op.gemm import Gemm
from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
from cutlass_cppgen.op.gemm_grouped import GroupedGemm
from cutlass_cppgen.op.op import OperationBase
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
from cutlass_cppgen.utils.lazy_import import lazy_import
this.memory_pool = None

View File

@@ -30,19 +30,19 @@
#
#################################################################################################
from cutlass.backend.arguments import *
from cutlass.backend.c_types import *
from cutlass.backend.compiler import ArtifactManager
from cutlass.backend.conv2d_operation import *
from cutlass.backend.epilogue import *
from cutlass.backend.frontend import *
from cutlass.backend.gemm_operation import *
from cutlass.backend.library import *
from cutlass.backend.memory_manager import PoolMemoryManager, create_memory_pool
from cutlass.backend.operation import *
from cutlass.backend.reduction_operation import *
from cutlass.backend.type_hint import *
from cutlass.backend.utils import *
from cutlass.backend.utils.device import device_cc
from cutlass_cppgen.backend.arguments import *
from cutlass_cppgen.backend.c_types import *
from cutlass_cppgen.backend.compiler import ArtifactManager
from cutlass_cppgen.backend.conv2d_operation import *
from cutlass_cppgen.backend.epilogue import *
from cutlass_cppgen.backend.frontend import *
from cutlass_cppgen.backend.gemm_operation import *
from cutlass_cppgen.backend.library import *
from cutlass_cppgen.backend.memory_manager import PoolMemoryManager, create_memory_pool
from cutlass_cppgen.backend.operation import *
from cutlass_cppgen.backend.reduction_operation import *
from cutlass_cppgen.backend.type_hint import *
from cutlass_cppgen.backend.utils import *
from cutlass_cppgen.backend.utils.device import device_cc
compiler = ArtifactManager()

View File

@@ -33,16 +33,16 @@
from math import prod
from typing import Union
from cutlass.utils.lazy_import import lazy_import
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
import numpy as np
import cutlass
from cutlass.backend.frontend import CupyFrontend, NumpyFrontend, TorchFrontend
from cutlass.backend.memory_manager import DevicePtrWrapper
from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
import cutlass_cppgen
from cutlass_cppgen.backend.frontend import CupyFrontend, NumpyFrontend, TorchFrontend
from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
class ArgumentBase:
@@ -122,7 +122,7 @@ class ArgumentBase:
Frees allocated device-side memory
"""
# Free any device memory allocated manually
if not cutlass.use_rmm:
if not cutlass_cppgen.use_rmm:
for name, buf in self.buffers.items():
if isinstance(buf, DevicePtrWrapper):
err, = cudart.cudaFree(buf.ptr)

View File

@@ -37,7 +37,7 @@ from cutlass_library import (
KernelScheduleType,
TileSchedulerType
)
from cutlass.backend.library import DataTypeSizeBytes
from cutlass_cppgen.backend.library import DataTypeSizeBytes
class GemmCoord_(ctypes.Structure):

View File

@@ -37,17 +37,17 @@ import sqlite3
import subprocess
import tempfile
from cutlass.utils.lazy_import import lazy_import
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
nvrtc = lazy_import("cuda.nvrtc")
from cutlass_library import SubstituteTemplate
import cutlass
from cutlass import CACHE_FILE, CUTLASS_PATH, cuda_install_path, logger
from cutlass.backend.gemm_operation import GemmOperationUniversal
from cutlass.backend.library import ApiVersion
from cutlass.backend.utils.device import device_cc
import cutlass_cppgen
from cutlass_cppgen import CACHE_FILE, CUTLASS_PATH, cuda_install_path, logger
from cutlass_cppgen.backend.gemm_operation import GemmOperationUniversal
from cutlass_cppgen.backend.library import ApiVersion
from cutlass_cppgen.backend.utils.device import device_cc
IncludeTemplate = r"""#include "${include}"
"""
@@ -93,7 +93,7 @@ class CompilationOptions:
opts.append(f"--include-path={incl}")
arch_flag = f"-arch=sm_{self.arch}"
if self.arch == 90 and int(cutlass.nvcc_version().split('.')[0]) >= 12:
if self.arch == 90 and int(cutlass_cppgen.nvcc_version().split('.')[0]) >= 12:
arch_flag += "a"
opts.append(arch_flag)
@@ -366,7 +366,7 @@ class ArtifactManager:
CUTLASS_PATH + "/python/cutlass/cpp/include",
]
cutlass.initialize_cuda_context()
cutlass_cppgen.initialize_cuda_context()
arch = device_cc()
host_compile_options = CompilationOptions(

View File

@@ -34,7 +34,7 @@ from __future__ import annotations
import ctypes
from typing import Union
from cutlass.utils.lazy_import import lazy_import
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
from cutlass_library import SubstituteTemplate
import numpy as np
@@ -65,17 +65,17 @@ from cutlass_library import (
get_complex_from_real,
)
from cutlass.backend.arguments import ArgumentBase
from cutlass.backend.c_types import dim3_, get_conv2d_arguments
from cutlass.backend.library import (
from cutlass_cppgen.backend.arguments import ArgumentBase
from cutlass_cppgen.backend.c_types import dim3_, get_conv2d_arguments
from cutlass_cppgen.backend.library import (
EmissionType,
TensorDescription,
TileDescription,
)
from cutlass.backend.memory_manager import device_mem_alloc
from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration
from cutlass.backend.utils.device import to_device_ptr
from cutlass.shape import GemmCoord
from cutlass_cppgen.backend.memory_manager import device_mem_alloc
from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
from cutlass_cppgen.backend.utils.device import to_device_ptr
from cutlass_cppgen.shape import GemmCoord
class Conv2dArguments(ArgumentBase):
@@ -84,9 +84,9 @@ class Conv2dArguments(ArgumentBase):
user-provided tensors into the kernel's argument.
:param operation: the Conv2d operation to take the argument
:type operation: :class:`cutlass.backend.Conv2dOperation`
:type operation: :class:`cutlass_cppgen.backend.Conv2dOperation`
:param problem_size: the Conv2d problem size
:type problem_size: :class:`cutlass.shape.Conv2dProblemSize`
:type problem_size: :class:`cutlass_cppgen.shape.Conv2dProblemSize`
:param A: tensor A
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
:param B: tensor B
@@ -98,7 +98,7 @@ class Conv2dArguments(ArgumentBase):
:param split_k_mode: conv2d split K mode, defaults to cutlass_library.library.SplitKMode.Serial
:type split_k_mode: cutlass_library.library.SplitKMode, optional
:param output_op: output operator, optional
:type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
:param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
:type stream: :class:`cuda.cuda.CUstream`
"""
@@ -380,19 +380,19 @@ class Conv2dOperation:
:type arch: int
:param tile_description: tile description
:type tile_description: :class:`cutlass.backend.TileDescription`
:type tile_description: :class:`cutlass_cppgen.backend.TileDescription`
:param A: tensor A description
:type A: :class:`cutlass.backend.TensorDescription`
:type A: :class:`cutlass_cppgen.backend.TensorDescription`
:param B: tensor B description
:type B: :class:`cutlass.backend.TensorDescription`
:type B: :class:`cutlass_cppgen.backend.TensorDescription`
:param C: tensor C description
:type C: :class:`cutlass.backend.TensorDescription`
:type C: :class:`cutlass_cppgen.backend.TensorDescription`
:param D: tensor D description
:type D: :class:`cutlass.backend.TensorDescription`
:type D: :class:`cutlass_cppgen.backend.TensorDescription`
:param element_epilogue: element type for computation in epilogue \
:type element_epilogue: cutlass_library.library.DataType
@@ -444,7 +444,7 @@ class Conv2dOperation:
Launch the cuda kernel with input arguments
:param arguments: conv2d arguments
:type arguments: :class:`cutlass.backend.Conv2dArguments`
:type arguments: :class:`cutlass_cppgen.backend.Conv2dArguments`
"""
# launch the kernel

View File

@@ -36,10 +36,10 @@ from cutlass_library import SubstituteTemplate
import numpy as np
from cutlass_library import DataType, DataTypeTag
from cutlass.backend.c_types import MatrixCoord_, tuple_factory
from cutlass.backend.frontend import NumpyFrontend
from cutlass.backend.library import ActivationOp, ActivationOpTag
from cutlass.utils.datatypes import is_numpy_tensor, is_torch_available, is_torch_tensor
from cutlass_cppgen.backend.c_types import MatrixCoord_, tuple_factory
from cutlass_cppgen.backend.frontend import NumpyFrontend
from cutlass_cppgen.backend.library import ActivationOp, ActivationOpTag
from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_available, is_torch_tensor
dtype2ctype = {
DataType.f16: ctypes.c_uint16,

View File

@@ -30,5 +30,5 @@
#
#################################################################################################
from cutlass.backend.evt.epilogue import EpilogueFunctorVisitor
from cutlass.backend.evt.frontend import PythonASTFrontend
from cutlass_cppgen.backend.evt.epilogue import EpilogueFunctorVisitor
from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend

View File

@@ -30,7 +30,7 @@
#
#################################################################################################
from cutlass.backend.evt.backend.sm80_emitter import Sm80Emitter
import cutlass.backend.evt.backend.sm80_nodes as sm80_nodes
from cutlass.backend.evt.backend.sm90_emitter import Sm90Emitter
import cutlass.backend.evt.backend.sm90_nodes as sm90_nodes
from cutlass_cppgen.backend.evt.backend.sm80_emitter import Sm80Emitter
import cutlass_cppgen.backend.evt.backend.sm80_nodes as sm80_nodes
from cutlass_cppgen.backend.evt.backend.sm90_emitter import Sm90Emitter
import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes

View File

@@ -35,7 +35,7 @@ Base class for Epilogue Visitor Emitter
"""
from cutlass_library import DataTypeTag
from cutlass.backend.evt.ir import TopoVisitorNode, DAGIR
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR
class FusionCallbacks:

View File

@@ -34,8 +34,8 @@
Emitter for Sm80 Epilogue Visitor
"""
from cutlass.backend.evt.backend.emitter_base import FusionCallbacks
from cutlass.backend import GemmOperationUniversal
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
from cutlass_cppgen.backend import GemmOperationUniversal
class Sm80Emitter:

View File

@@ -32,7 +32,7 @@
from cutlass_library import DataTypeSize, DataTypeTag
from cutlass.backend.evt.ir import (
from cutlass_cppgen.backend.evt.ir import (
# Load Node
AccumulatorImpl,
AuxLoadImpl,
@@ -50,7 +50,7 @@ from cutlass.backend.evt.ir import (
ScalarReductionImpl
)
from cutlass.backend.library import (
from cutlass_cppgen.backend.library import (
FloatRoundStyleTag,
FunctionalOp,
op_tag,

View File

@@ -35,8 +35,8 @@ Emitter for Sm90 Epilogue Visitor
"""
from cutlass_library import DataTypeTag, EpilogueScheduleTag
from cutlass.backend import GemmOperationUniversal
from cutlass.backend.evt.backend.emitter_base import FusionCallbacks
from cutlass_cppgen.backend import GemmOperationUniversal
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
class CollectiveEpilogue:

View File

@@ -33,7 +33,7 @@
from pycute import product
from cutlass_library import DataTypeSize, DataTypeTag
from cutlass.backend.evt.ir import (
from cutlass_cppgen.backend.evt.ir import (
# Load Node
AccumulatorImpl,
AuxLoadImpl,
@@ -53,7 +53,7 @@ from cutlass.backend.evt.ir import (
StoreNode,
StoreDImpl,
)
from cutlass.backend.library import (
from cutlass_cppgen.backend.library import (
FloatRoundStyleTag,
FunctionalOp,
op_tag,

View File

@@ -36,16 +36,16 @@ Epilogue Visitor interface for compiling, and running visitor-based epilogue.
import ctypes
from cutlass.utils.lazy_import import lazy_import
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
from cutlass_library import DataType
import numpy as np
from cutlass.backend.epilogue import EpilogueFunctorBase
import cutlass.backend.evt.backend
from cutlass.backend.frontend import TensorFrontend
from cutlass.utils.datatypes import is_numpy_tensor
from cutlass.backend.evt.passes.util import cc_map
from cutlass_cppgen.backend.epilogue import EpilogueFunctorBase
import cutlass_cppgen.backend.evt.backend
from cutlass_cppgen.backend.frontend import TensorFrontend
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
from cutlass_cppgen.backend.evt.passes.util import cc_map
class EpilogueFunctorVisitor(EpilogueFunctorBase):
@@ -58,7 +58,7 @@ class EpilogueFunctorVisitor(EpilogueFunctorBase):
"""
def __init__(self, cc: int, visitor, element_compute=DataType.f32) -> None:
# Type of Emitter based on CC
self.emit_cls = getattr(cutlass.backend.evt.backend, f"Sm{cc_map[cc]}Emitter")
self.emit_cls = getattr(cutlass_cppgen.backend.evt.backend, f"Sm{cc_map[cc]}Emitter")
# Visitor Types
self.visitor = visitor

View File

@@ -30,4 +30,4 @@
#
#################################################################################################
from cutlass.backend.evt.frontend.python_ast import PythonASTFrontend
from cutlass_cppgen.backend.evt.frontend.python_ast import PythonASTFrontend

View File

@@ -37,14 +37,14 @@ Base class for Python EVT Frontend
from typing import Union
from cutlass_library import DataType
from cutlass.backend.evt.ir import (
from cutlass_cppgen.backend.evt.ir import (
ComputeNode,
DAGIR,
LayoutNode,
LoadNode,
StoreNode,
)
from cutlass.backend.evt.passes import (
from cutlass_cppgen.backend.evt.passes import (
EVTGraphDrawer,
EVTPassManager,
GetSmemSize,
@@ -56,9 +56,9 @@ from cutlass.backend.evt.passes import (
PassPreprocessRed,
PassShapeTypePropagation,
)
from cutlass.backend.utils import device_cc
from cutlass.epilogue.evt_ops import permute, reshape
from cutlass.utils.datatypes import library_type
from cutlass_cppgen.backend.utils import device_cc
from cutlass_cppgen.epilogue.evt_ops import permute, reshape
from cutlass_cppgen.utils.datatypes import library_type
class EVTFrontendBase:

View File

@@ -40,10 +40,10 @@ import textwrap
from cutlass_library import DataType
import cutlass
from cutlass.backend.evt.frontend.frontend_base import EVTFrontendBase
from cutlass.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu
from cutlass.backend.library import FunctionalOp
import cutlass_cppgen
from cutlass_cppgen.backend.evt.frontend.frontend_base import EVTFrontendBase
from cutlass_cppgen.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu
from cutlass_cppgen.backend.library import FunctionalOp
class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):

View File

@@ -30,10 +30,10 @@
#
#################################################################################################
from cutlass.backend.evt.ir.compute_nodes import ComputeNode, ComputeImpl
from cutlass.backend.evt.ir.dag_ir import DAGIR
from cutlass.backend.evt.ir.layout_nodes import LayoutNode
from cutlass.backend.evt.ir.load_nodes import (
from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode, ComputeImpl
from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR
from cutlass_cppgen.backend.evt.ir.layout_nodes import LayoutNode
from cutlass_cppgen.backend.evt.ir.load_nodes import (
LoadNode,
AccumulatorImpl,
LoadSrcImpl,
@@ -42,8 +42,8 @@ from cutlass.backend.evt.ir.load_nodes import (
ColumnBroadcastImpl,
ScalarBroadcastImpl
)
from cutlass.backend.evt.ir.node import TopoVisitorNode, NoOpImpl
from cutlass.backend.evt.ir.store_nodes import (
from cutlass_cppgen.backend.evt.ir.node import TopoVisitorNode, NoOpImpl
from cutlass_cppgen.backend.evt.ir.store_nodes import (
StoreNode,
StoreDImpl,
AuxStoreImpl,

View File

@@ -34,8 +34,8 @@
Python registration for compute nodes in EVT
"""
from cutlass.backend.evt.ir.node import NodeBase, ImplBase
from cutlass.backend.library import FloatRoundStyle
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
from cutlass_cppgen.backend.library import FloatRoundStyle
class ComputeImplBase(ImplBase):

View File

@@ -38,10 +38,10 @@ import networkx as nx
from cutlass_library import DataType
from cutlass.backend.evt.ir.compute_nodes import ComputeNode
from cutlass.backend.evt.ir.node import NodeBase
from cutlass.backend.library import ActivationOp
from cutlass.backend.utils import device_cc
from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode
from cutlass_cppgen.backend.evt.ir.node import NodeBase
from cutlass_cppgen.backend.library import ActivationOp
from cutlass_cppgen.backend.utils import device_cc
class DAGIR:

View File

@@ -41,10 +41,10 @@ from copy import deepcopy
from cutlass_library import LayoutType
from pycute import product, flatten
import cutlass
from cutlass.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list
from cutlass.backend.evt.ir.node import NodeBase
from cutlass.backend.evt.ir.tensor import Tensor
import cutlass_cppgen
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list
from cutlass_cppgen.backend.evt.ir.node import NodeBase
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
class PermutationImpl:

View File

@@ -36,9 +36,9 @@ Load nodes and implementations
import ctypes
from cutlass.backend.c_types import tuple_factory
from cutlass.backend.epilogue import dtype2ctype, to_ctype_value
from cutlass.backend.evt.ir.node import NodeBase, ImplBase
from cutlass_cppgen.backend.c_types import tuple_factory
from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
class LoadImplBase(ImplBase):

View File

@@ -39,8 +39,8 @@ from re import sub
from cutlass_library import LayoutType
from cutlass.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple
from cutlass.backend.evt.ir.tensor import Tensor
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
class ImplBase:
@@ -170,7 +170,7 @@ class NodeBase:
@property
def tensor(self) -> Tensor:
"""
Return the output tensor (concept: cutlass.backend.evt.ir.tensor)
Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
"""
return self._tensor

View File

@@ -38,11 +38,11 @@ import ctypes
from cutlass_library import DataType
from cutlass.backend.c_types import tuple_factory
from cutlass.backend.epilogue import dtype2ctype, to_ctype_value
from cutlass.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl
from cutlass.backend.evt.ir.tensor import Tensor
from cutlass.backend.library import FloatRoundStyle, FunctionalOp
from cutlass_cppgen.backend.c_types import tuple_factory
from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
from cutlass_cppgen.backend.library import FloatRoundStyle, FunctionalOp
class StoreImplBase(ImplBase):
@@ -249,7 +249,7 @@ class StoreNode(NodeBase):
@property
def store_tensor(self) -> Tensor:
"""
Return the output tensor (concept: cutlass.backend.evt.ir.tensor)
Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
"""
return self._store_tensor

View File

@@ -36,7 +36,7 @@ High-level class for tensor
from cutlass_library import LayoutType
from cutlass.backend.evt.ir.layout_algorithm import (
from cutlass_cppgen.backend.evt.ir.layout_algorithm import (
Layout,
broadcast,
canonicalization,
@@ -44,7 +44,7 @@ from cutlass.backend.evt.ir.layout_algorithm import (
reshape,
_reverse_tuple
)
from cutlass.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type
from cutlass_cppgen.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type
class Tensor:

View File

@@ -30,13 +30,13 @@
#
#################################################################################################
from cutlass.backend.evt.passes.graph_drawer import EVTGraphDrawer
from cutlass.backend.evt.passes.pass_argument_type import PassGetArgumentType
from cutlass.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
from cutlass.backend.evt.passes.pass_get_impl import PassGetImpl
from cutlass.backend.evt.passes.pass_fix_element_d import PassFixElementD
from cutlass.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
from cutlass.backend.evt.passes.pass_manager import EVTPassManager
from cutlass.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
from cutlass.backend.evt.passes.smem_size_calculator import GetSmemSize
from cutlass_cppgen.backend.evt.passes.graph_drawer import EVTGraphDrawer
from cutlass_cppgen.backend.evt.passes.pass_argument_type import PassGetArgumentType
from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassManager
from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
from cutlass_cppgen.backend.evt.passes.smem_size_calculator import GetSmemSize

View File

@@ -35,7 +35,7 @@ import subprocess
from cutlass_library import DataTypeTag
from cutlass.backend.evt.ir.dag_ir import DAGIR
from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR
_COLOR_MAP = {

View File

@@ -34,12 +34,12 @@
Construct the epilogue visitor argument type
"""
from cutlass.backend.c_types import visitor_factory
from cutlass.backend.evt.ir import TopoVisitorNode
from cutlass.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
from cutlass.backend.evt.passes.pass_get_impl import PassGetImpl
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
from cutlass_cppgen.backend.c_types import visitor_factory
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode
from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
class PassGetArgumentType(EVTPassBase):

View File

@@ -37,10 +37,10 @@ by the topological visitor, while the rest of the graph will be implemented with
from copy import deepcopy
from cutlass.backend.evt.ir import DAGIR, TopoVisitorNode
from cutlass.backend.evt.passes.pass_get_impl import PassGetImpl
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
from cutlass_cppgen.backend.evt.ir import DAGIR, TopoVisitorNode
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
class PassDAG2Tree(EVTPassBase):

View File

@@ -37,8 +37,8 @@ In Sm90 epilogue visitor, the node writing D to gmem does not have internal
element converter, so the compute node producing D must have element_output = type(D).
"""
from cutlass.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
class PassFixElementD(EVTPassBase):

View File

@@ -39,13 +39,13 @@ on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadca
This pass infers the underlying impl of each node
"""
import cutlass.backend.evt.backend as evt_backend
from cutlass.backend.evt.ir import DAGIR, LoadNode
from cutlass.backend.evt.passes.pass_fix_element_d import PassFixElementD
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
from cutlass.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
from cutlass.backend.evt.passes.util import cc_map
import cutlass_cppgen.backend.evt.backend as evt_backend
from cutlass_cppgen.backend.evt.ir import DAGIR, LoadNode
from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
from cutlass_cppgen.backend.evt.passes.util import cc_map
class PassGetImpl(EVTPassBase):

View File

@@ -36,9 +36,9 @@ Eliminate layout manipulation nodes
from copy import deepcopy
from cutlass.backend.evt.ir import DAGIR, LayoutNode
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
from cutlass_cppgen.backend.evt.ir import DAGIR, LayoutNode
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
class PassLayoutManipulateElimination(EVTPassBase):

View File

@@ -38,8 +38,8 @@ from typing import Any
import networkx as nx
from cutlass.backend.evt.ir import DAGIR
from cutlass.backend.evt.passes.util import cc_map
from cutlass_cppgen.backend.evt.ir import DAGIR
from cutlass_cppgen.backend.evt.passes.util import cc_map
class EVTPassBase:

View File

@@ -36,8 +36,8 @@ No op elimination node
from typing import Any
from cutlass.backend.evt.ir import NoOpImpl
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.ir import NoOpImpl
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
class PassNoOpElimination(EVTPassBase):

View File

@@ -38,8 +38,8 @@ This pass fuses these into a single store node, and then replaces all uses of th
current node with the new store node.
"""
from cutlass.backend.evt.ir import ComputeNode, StoreNode
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.ir import ComputeNode, StoreNode
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
class PassPreprocessRed(EVTPassBase):

View File

@@ -34,9 +34,9 @@
Shape and type propagation pass
"""
from cutlass.backend.evt.ir.node import NodeBase
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
from cutlass.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
from cutlass_cppgen.backend.evt.ir.node import NodeBase
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
class PassShapeTypePropagation(EVTPassBase):

View File

@@ -37,9 +37,9 @@ Compute the shared memory size in bytes
import cutlass_library
from pycute import shape_div, product
import cutlass
from cutlass.backend.evt.ir import TopoVisitorNode, DAGIR
from cutlass.backend.library import DataTypeSize
import cutlass_cppgen
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR
from cutlass_cppgen.backend.library import DataTypeSize
class GetSmemSize:

View File

@@ -31,12 +31,12 @@
#################################################################################################
from __future__ import annotations
from cutlass.utils.lazy_import import lazy_import
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
import numpy as np
from cutlass.backend.memory_manager import device_mem_alloc, todevice
from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
class NumpyFrontend:

View File

@@ -35,7 +35,7 @@ import copy
import ctypes
import enum
from cutlass.utils.lazy_import import lazy_import
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
from cutlass_library import SubstituteTemplate
@@ -74,8 +74,8 @@ from cutlass_library import (
TileSchedulerType,
get_complex_from_real
)
from cutlass.backend.arguments import ArgumentBase
from cutlass.backend.c_types import (
from cutlass_cppgen.backend.arguments import ArgumentBase
from cutlass_cppgen.backend.c_types import (
GemmCoord_,
GemmCoordBatched_,
GenericMainloopArguments3x_,
@@ -88,7 +88,7 @@ from cutlass.backend.c_types import (
get_mainloop_arguments_3x,
get_tile_scheduler_arguments_3x,
)
from cutlass.backend.library import (
from cutlass_cppgen.backend.library import (
ApiVersion,
EmissionType,
SchedulerMode,
@@ -97,11 +97,11 @@ from cutlass.backend.library import (
TileDescription,
api_version,
)
from cutlass.backend.memory_manager import device_mem_alloc, todevice
from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration
from cutlass.backend.type_hint import GemmOperation, Tensor
from cutlass.backend.utils.device import device_sm_count
from cutlass.shape import GemmCoord, MatrixCoord
from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice
from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
from cutlass_cppgen.backend.type_hint import GemmOperation, Tensor
from cutlass_cppgen.backend.utils.device import device_sm_count
from cutlass_cppgen.shape import GemmCoord, MatrixCoord
################################################################################
@@ -116,9 +116,9 @@ def leading_dimension(layout: LayoutType, shape: MatrixCoord) -> int:
Returns the leading dimenson of a tensor with layout ``layout`` and shape ``shape``.
:param layout: layout of the tensor
:type layout: cutlass.shape.LayoutType
:type layout: cutlass_cppgen.shape.LayoutType
:param shape: shape of the tensor
:type shape: cutlass.shape.MatrixCoord
:type shape: cutlass_cppgen.shape.MatrixCoord
:return: leading dimension of the tensor
:rtype: int
@@ -144,11 +144,11 @@ class GemmArguments2x(ArgumentBase):
user-provide tensors into the kernel's argument
:param operation: the GEMM operation to take the argument
:type operation: :class:`cutlass.backend.GemmOperationUniversal` |
:class:`cutlass.backend.GemmOperationGrouped`
:type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
:class:`cutlass_cppgen.backend.GemmOperationGrouped`
:param problem_size: GEMM problem size gemm(M, N, K)
:type operation: :class:`cutlass.shape.GemmCoord`
:type operation: :class:`cutlass_cppgen.shape.GemmCoord`
:param A: tensor A
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
@@ -166,7 +166,7 @@ class GemmArguments2x(ArgumentBase):
:type gemm_mode: :class:`cutlass_library.GemmUniversalMode`
:param output_op: output operator, optional
:type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
:param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
:type stream: :class:`cuda.cuda.CUstream`
@@ -371,11 +371,11 @@ class GemmArguments2xStreamK(GemmArguments2x):
user-provide tensors into the kernel's argument
:param operation: the GEMM operation to take the argument
:type operation: :class:`cutlass.backend.GemmOperationUniversal` |
:class:`cutlass.backend.GemmOperationGrouped`
:type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
:class:`cutlass_cppgen.backend.GemmOperationGrouped`
:param problem_size: GEMM problem size gemm(M, N, K)
:type operation: :class:`cutlass.shape.GemmCoord`
:type operation: :class:`cutlass_cppgen.shape.GemmCoord`
:param A: tensor A
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
@@ -393,7 +393,7 @@ class GemmArguments2xStreamK(GemmArguments2x):
:type gemm_mode: :class:`cutlass_library.GemmUniversalMode`
:param output_op: output operator, optional
:type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
"""
def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
@@ -483,11 +483,11 @@ class GemmArguments3x(GemmArguments2x):
user-provide tensors into the kernel's argument
:param operation: the GEMM operation to take the argument
:type operation: :class:`cutlass.backend.GemmOperationUniversal` |
:class:`cutlass.backend.GemmOperationGrouped`
:type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
:class:`cutlass_cppgen.backend.GemmOperationGrouped`
:param problem_size: GEMM problem size gemm(M, N, K)
:type operation: :class:`cutlass.shape.GemmCoord`
:type operation: :class:`cutlass_cppgen.shape.GemmCoord`
:param A: tensor A
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
@@ -505,7 +505,7 @@ class GemmArguments3x(GemmArguments2x):
:type gemm_mode: GemmUniversalMode
:param output_op: output operator, optional
:type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
"""
def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
@@ -631,11 +631,11 @@ def GemmArguments(operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMo
or 3x arguments depending on the `arch` field specified in `operation`.
:param operation: the GEMM operation to take the argument
:type operation: :class:`cutlass.backend.GemmOperationUniversal` |
:class:`cutlass.backend.GemmOperationGrouped`
:type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
:class:`cutlass_cppgen.backend.GemmOperationGrouped`
:param problem_size: GEMM problem size gemm(M, N, K)
:type operation: :class:`cutlass.shape.GemmCoord`
:type operation: :class:`cutlass_cppgen.shape.GemmCoord`
:param A: tensor A
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
@@ -653,7 +653,7 @@ def GemmArguments(operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMo
:type gemm_mode: :class:`cutlass_library.GemmUniversalMode`
:param output_op: output operator, optional
:type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
"""
if operation.swizzling_functor == SwizzlingFunctor.StreamK:
if operation.api == ApiVersion.v3x:
@@ -670,10 +670,10 @@ class GemmGroupedArguments:
user-provide tensors into the kernel's argument
:param operation: the GEMM Grouped operation to take the argument
:type operation: :class:`cutlass.backend.GemmOperationGrouped`
:type operation: :class:`cutlass_cppgen.backend.GemmOperationGrouped`
:param problem_size: list of GEMM problem size gemm(M, N, K)
:type operation: list[:class:`cutlass.shape.GemmCoord`]
:type operation: list[:class:`cutlass_cppgen.shape.GemmCoord`]
:param A: list of tensor A
:type A: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
@@ -688,7 +688,7 @@ class GemmGroupedArguments:
:type D: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
:param output_op: output operator, optional
:type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
:param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
:type stream: :class:`cuda.cuda.CUstream`

View File

@@ -417,7 +417,7 @@ def CalculateSmemUsagePerStage(operation):
:param op: operation for which the maximum stages should be computed. If stages are
set via the `op.tile_description.stages` parameter, this setting is ignored
in the present calculation
:type op: cutlass.backend.Operation
:type op: cutlass_cppgen.backend.Operation
:return: number of bytes of shared memory consumed by a single stage
:rtype: int
@@ -442,7 +442,7 @@ def CalculateSmemUsage(operation):
:param op: operation for which the maximum stages should be computed. If stages are
set via the `op.tile_description.stages` parameter, this setting is ignored
in the present calculation
:type op: cutlass.backend.Operation
:type op: cutlass_cppgen.backend.Operation
:return: int
"""

View File

@@ -32,11 +32,11 @@
import numpy as np
import cutlass
from cutlass.utils.datatypes import is_numpy_tensor
from cutlass.utils.lazy_import import lazy_import
import cutlass_cppgen
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
from cutlass_cppgen.utils.lazy_import import lazy_import
if cutlass.use_rmm:
if cutlass_cppgen.use_rmm:
import rmm
else:
cudart = lazy_import("cuda.cudart")
@@ -73,7 +73,7 @@ def _todevice(host_data):
"""
Helper for transferring host data to device memory
"""
if cutlass.use_rmm:
if cutlass_cppgen.use_rmm:
return rmm.DeviceBuffer.to_device(host_data.tobytes())
else:
nbytes = len(host_data.tobytes())
@@ -100,7 +100,7 @@ def todevice(host_data, dtype=np.float32):
def device_mem_alloc(size):
if cutlass.use_rmm:
if cutlass_cppgen.use_rmm:
return rmm.DeviceBuffer(size=size)
else:
err, ptr = cudart.cudaMalloc(size)
@@ -114,7 +114,7 @@ def align_size(size, alignment=256):
def create_memory_pool(init_pool_size=0, max_pool_size=2 ** 34):
if cutlass.use_rmm:
if cutlass_cppgen.use_rmm:
memory_pool = PoolMemoryManager(init_pool_size=init_pool_size, max_pool_size=max_pool_size)
return memory_pool
else:

View File

@@ -31,10 +31,10 @@
#################################################################################################
import ctypes
from cutlass.utils.lazy_import import lazy_import
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
from cutlass.backend.utils.device import device_cc
from cutlass_cppgen.backend.utils.device import device_cc
_supports_cluster_launch = None

View File

@@ -34,7 +34,7 @@ from __future__ import annotations
import ctypes
from typing import Union
from cutlass.utils.lazy_import import lazy_import
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
import numpy as np
@@ -47,14 +47,14 @@ from cutlass_library import (
SubstituteTemplate
)
import cutlass
from cutlass.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params
from cutlass.backend.frontend import NumpyFrontend, TorchFrontend
from cutlass.backend.library import TensorDescription
from cutlass.backend.memory_manager import DevicePtrWrapper
from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration
from cutlass.shape import MatrixCoord
from cutlass.utils.datatypes import is_numpy_tensor, is_torch_tensor
import cutlass_cppgen
from cutlass_cppgen.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params
from cutlass_cppgen.backend.frontend import NumpyFrontend, TorchFrontend
from cutlass_cppgen.backend.library import TensorDescription
from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper
from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
from cutlass_cppgen.shape import MatrixCoord
from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_tensor
class ReductionOperation:
@@ -200,7 +200,7 @@ class ReductionArguments:
Frees allocated device-side memory
"""
# Free any device memory allocated manually
if not cutlass.use_rmm:
if not cutlass_cppgen.use_rmm:
for attr in ["destination_buffer", "source_buffer"]:
if hasattr(self, attr):
buf = getattr(self, attr)

View File

@@ -30,4 +30,4 @@
#
################################################################################
from cutlass.backend.utils.device import check_cuda_errors, device_cc
from cutlass_cppgen.backend.utils.device import check_cuda_errors, device_cc

View File

@@ -35,12 +35,12 @@ Utility functions for interacting with the device
"""
from __future__ import annotations
from cutlass.utils.lazy_import import lazy_import
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
import cutlass
from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
import cutlass_cppgen
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
def check_cuda_errors(result: list):
@@ -77,7 +77,7 @@ def device_cc(device: int = -1) -> int:
:rtype: int
"""
if device == -1:
device = cutlass.device_id()
device = cutlass_cppgen.device_id()
deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device))
major = str(deviceProp.major)
@@ -87,7 +87,7 @@ def device_cc(device: int = -1) -> int:
def device_sm_count(device: int = -1):
if device == -1:
device = cutlass.device_id()
device = cutlass_cppgen.device_id()
err, device_sm_count = cuda.cuDeviceGetAttribute(
cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device
)

View File

@@ -30,4 +30,4 @@
#
#################################################################################################
from cutlass.emit.pytorch import pytorch
from cutlass_cppgen.emit.pytorch import pytorch

View File

@@ -34,10 +34,10 @@
Common utilities for emitting CUTLASS kernels
"""
import cutlass
import cutlass_cppgen
# Strings used for printing information about the generation of emitted scripts
_AUTOGEN_STR = f"This file was automatically generated by the CUTLASS {cutlass.__version__} Python interface (https://github.com/nvidia/cutlass/python)"
_AUTOGEN_STR = f"This file was automatically generated by the CUTLASS {cutlass_cppgen.__version__} Python interface (https://github.com/nvidia/cutlass/python)"
_CSTYLE_AUTOGEN_COMMENT = f"""// {_AUTOGEN_STR}

View File

@@ -39,9 +39,9 @@ Example usage with JIT compilation:
.. highlight:: python
.. code-block:: python
plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass_library.LayoutType.RowMajor)
plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_library.LayoutType.RowMajor)
op = plan.construct()
mod = cutlass.emit.pytorch(op, 'cutlass_gemm', 80, jit=True)
mod = cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=True)
# Generate inputs for the GEMM
A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]
@@ -55,9 +55,9 @@ Example usage without JIT compilation:
.. highlight:: python
.. code-block:: python
plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)
plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
op = plan.construct()
cutlass.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output')
cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output')
After this call, the directory ``output`` contains ``setup.py``,
``cutlass_gemm.cpp``, and ``cutlass_gemm_kernel.cu``. The module can be built from
@@ -83,12 +83,12 @@ import os
from cutlass_library import ConvKind, ConvKindNames, DataType, SubstituteTemplate
from cutlass import CUTLASS_PATH, logger, swizzle
from cutlass.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal
from cutlass.backend.conv2d_operation import Conv2dOperation
from cutlass.backend.library import ApiVersion
from cutlass.emit import common
from cutlass.utils.datatypes import is_torch_available
from cutlass_cppgen import CUTLASS_PATH, logger, swizzle
from cutlass_cppgen.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal
from cutlass_cppgen.backend.conv2d_operation import Conv2dOperation
from cutlass_cppgen.backend.library import ApiVersion
from cutlass_cppgen.emit import common
from cutlass_cppgen.utils.datatypes import is_torch_available
if is_torch_available():
import torch

View File

@@ -30,7 +30,7 @@
#
#################################################################################################
from cutlass.epilogue.epilogue import (
from cutlass_cppgen.epilogue.epilogue import (
get_activations,
get_activation_epilogue,
gelu,
@@ -44,7 +44,7 @@ from cutlass.epilogue.epilogue import (
trace
)
from cutlass.epilogue.evt_ops import (
from cutlass_cppgen.epilogue.evt_ops import (
max,
multiply_add,
sum,

View File

@@ -39,11 +39,11 @@ code like the following for GEMM:
.. highlight:: python
.. code-block:: python
plan = cutlass.op.Gemm(element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)
plan.activation = cutlass.epilogue.relu
plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
plan.activation = cutlass_cppgen.epilogue.relu
"""
from cutlass.backend import epilogue, device_cc
from cutlass_cppgen.backend import epilogue, device_cc
gelu = epilogue.gelu
@@ -111,7 +111,7 @@ def get_activation_epilogue(
"""
Frontend for EVT that generates epilogue functor through tracing the input function
"""
from cutlass.backend.evt.frontend import PythonASTFrontend
from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend
def trace(fn, example_tensors, **kwargs):
@@ -124,7 +124,7 @@ def trace(fn, example_tensors, **kwargs):
.. hightlight:: python
.. code-block:: python
import cutlass.backend.evt
import cutlass_cppgen.backend.evt
# Define epilogue function as Python callable
def example_fn(accum, C, alpha, beta, gamma):
@@ -142,7 +142,7 @@ def trace(fn, example_tensors, **kwargs):
}
# Generate the epilogue functor
epilogue_visitor = cutlass.epilogue.trace(example_fn, example_inputs)
epilogue_visitor = cutlass_cppgen.epilogue.trace(example_fn, example_inputs)
"""
if callable(fn):
class EpilogueFunctor(PythonASTFrontend):

View File

@@ -36,7 +36,7 @@ Collection of builtin functions used for host reference in EVT
import numpy as np
from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_available, is_torch_tensor
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_available, is_torch_tensor
if is_torch_available():
import torch

View File

@@ -40,9 +40,9 @@ import logging
import cutlass_library
from cutlass_library.library import ConvKind, IteratorAlgorithm, StrideSupport, GroupMode
import cutlass
from cutlass.utils.check import valid_stage_count
from cutlass.utils.datatypes import td_from_profiler_td, td_from_profiler_op
import cutlass_cppgen
from cutlass_cppgen.utils.check import valid_stage_count
from cutlass_cppgen.utils.datatypes import td_from_profiler_td, td_from_profiler_op
_generator_ccs = [50, 60, 61, 70, 75, 80, 90]
@@ -99,14 +99,14 @@ class KernelsForDataType:
ops.extend(alignment_ops)
return ops
def default_operation(self, math_operation: cutlass.MathOperation):
def default_operation(self, math_operation: cutlass_cppgen.MathOperation):
key = sorted(list(self.kernels_by_alignment.keys()))[0]
kernels = self.kernels_by_alignment[key]
if math_operation is not None:
kernels = [x for x in kernels if x.tile_description.math_instruction.math_operation == math_operation]
return kernels[0]
def operations(self, alignment_A: int, alignment_B: int, alignment_C: int, math_operation: cutlass.MathOperation):
def operations(self, alignment_A: int, alignment_B: int, alignment_C: int, math_operation: cutlass_cppgen.MathOperation):
"""
Returns operations satisfying the alignment constraints
@@ -117,7 +117,7 @@ class KernelsForDataType:
:param alignment_C: alignment constraint of operations to return
:type alignment_C: int
:param math_operation: math operation to consider
:type math_operation: cutlass.MathOperation
:type math_operation: cutlass_cppgen.MathOperation
:return: list of operations
:rtype: list
@@ -158,14 +158,14 @@ class KernelsForDataType:
return operand_list.index(key)
def find_alignment(self, shape: tuple, layout: cutlass.LayoutType, operand=str) -> int:
def find_alignment(self, shape: tuple, layout: cutlass_cppgen.LayoutType, operand=str) -> int:
"""
Returns the most preferable alignment for a given shape and layout
:param shape: extent of each dimension of the tensor
:type shape: tuple
:param layout: layout of the tensor
:type layout: cutlass.LayoutType
:type layout: cutlass_cppgen.LayoutType
:param operand: descriptor of the operand in question
:type operand: str
@@ -175,11 +175,11 @@ class KernelsForDataType:
operand_idx = self._operand_idx(operand)
# Determine the leading dimension of the shape
if layout == cutlass.LayoutType.ColumnMajor:
if layout == cutlass_cppgen.LayoutType.ColumnMajor:
ld = shape[-2]
elif layout == cutlass.LayoutType.RowMajor:
elif layout == cutlass_cppgen.LayoutType.RowMajor:
ld = shape[-1]
elif layout == cutlass.LayoutType.TensorNHWC:
elif layout == cutlass_cppgen.LayoutType.TensorNHWC:
ld = shape[-1]
else:
raise Exception(f"Unexpected or unsupported layout {layout}")
@@ -204,12 +204,12 @@ class KernelsForDataType:
for alignment in self.kernels_by_alignment.keys():
self.kernels_by_alignment[alignment].sort(key=key, reverse=True)
def supports_math_operation(self, math_operation: cutlass.MathOperation) -> bool:
def supports_math_operation(self, math_operation: cutlass_cppgen.MathOperation) -> bool:
"""
Returns whether `math_operation` is supported by at least one operation.
:param math_operation: math operation to consider
:type math_operation: cutlass.MathOperation
:type math_operation: cutlass_cppgen.MathOperation
:return: whether math_operation is supported by at least one operation
:rtype: bool
@@ -262,7 +262,7 @@ class ArchOptions:
# descriptions for the target CC
generate_function_name = "GenerateSM" + str(kernel_cc)
if not hasattr(cutlass_library.generator, generate_function_name):
cutlass.logger.warning(f"No generator found for architecture {kernel_cc}")
cutlass_cppgen.logger.warning(f"No generator found for architecture {kernel_cc}")
return
generate_function = getattr(cutlass_library.generator, generate_function_name)
@@ -270,16 +270,16 @@ class ArchOptions:
# for the target CC
args = [
"--kernels=all",
f"--log-level={logging.getLevelName(cutlass.logger.level)}"
f"--log-level={logging.getLevelName(cutlass_cppgen.logger.level)}"
]
manifest_args = cutlass_library.generator.define_parser().parse_args(args)
manifest = cutlass_library.manifest.Manifest(manifest_args)
generate_function(manifest, cutlass._nvcc_version)
generate_function(manifest, cutlass_cppgen._nvcc_version)
if operation_kind not in manifest.operations:
# No kernels generated for this architecture, this could be because the CUDA
# toolkit is insufficient to support operations in this CC
cutlass.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}")
cutlass_cppgen.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}")
return
# Only one CC should be returned, given the setup above of calling only the generation scripts
@@ -358,7 +358,7 @@ class ArchOptions:
# Add FP8 A/B with FP32 C
for type_comb in combinations_with_replacement(fp8_types, 2):
types.append(type_comb + (cutlass.DataType.f32,))
types.append(type_comb + (cutlass_cppgen.DataType.f32,))
layouts = [
(cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.RowMajor),
@@ -444,7 +444,7 @@ class ArchOptions:
:param layout_comb: tuple of data types for (layout_A, layout_B)
:type layout_comb: tuple[cutlass_library.LayoutType]
:param math_operation: math operation to consider or None if any can be considered
:type math_operation: cutlass.MathOperation
:type math_operation: cutlass_cppgen.MathOperation
:return: set of operation classes that support the provided data type and layout combination
:rtype: set
@@ -484,7 +484,7 @@ class ArchOptions:
:param layout_b: layout of operand B
:type layout_b: cutlass_library.LayoutType
:param math_operation: math operation to consider
:type math_operation: cutlass.MathOperation
:type math_operation: cutlass_cppgen.MathOperation
:return: set of operation classes that support the provided data type combination
:rtype: set
@@ -524,7 +524,7 @@ class ArchOptions:
:param layout_b: layout of operand B
:type layout_b: cutlass_library.LayoutType
:param math_operation: math operation to consider
:type math_operation: cutlass.MathOperation
:type math_operation: cutlass_cppgen.MathOperation
:return: container of kernels by alignment supported by the provided combination of parameters
:rtype: KernelsForDataType

View File

@@ -30,7 +30,7 @@
#
#################################################################################################
from cutlass.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
from cutlass.op.gemm import Gemm
from cutlass.op.gemm_grouped import GroupedGemm
from cutlass.op.op import OperationBase
from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
from cutlass_cppgen.op.gemm import Gemm
from cutlass_cppgen.op.gemm_grouped import GroupedGemm
from cutlass_cppgen.op.op import OperationBase

View File

@@ -47,7 +47,7 @@
.. code-block:: python
# A, B, C, and D are torch/numpy/cupy tensor objects
plan = cutlass.op.Conv(A, B, C, D)
plan = cutlass_cppgen.op.Conv(A, B, C, D)
plan.run(stride=(1, 1), padding=(0, 0), dilation=(1, 1))
One can also use the interface by specifying data types of operands at construction
@@ -57,11 +57,11 @@
.. code-block:: python
# The following is shorthand for:
# cutlass.op.Conv2d(kind="fprop",
# cutlass_cppgen.op.Conv2d(kind="fprop",
# element_A=torch.float32, element_B=torch.float32,
# element_C=torch.float32, element_D=torch.float32,
# element_accumulator=torch.float32)
plan = cutlass.op.Conv2d(kind="fprop", element=torch.float32)
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=torch.float32)
A0 = torch.rand((128, 256), dtype=torch.float32, device='cuda')
B0 = torch.rand((256, 64), dtype=torch.float32, device='cuda')
@@ -81,7 +81,7 @@
.. highlight:: python
.. code-block:: python
plan = cutlass.op.Conv2d(kind="fprop", element=np.float32)
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
# Do other work...
@@ -96,15 +96,15 @@
.. highlight:: python
.. code-block:: python
plan = cutlass.op.Conv2d(kind="fprop", element=np.float32)
plan.activation = cutlass.epilogue.relu
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
plan.activation = cutlass_cppgen.epilogue.relu
Operations can also be run asynchronously:
.. highlight:: python
.. code-block:: python
plan = cutlass.op.Conv2d(kind="fprop", element=np.float32)
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
args = plan.run()
# Do other work...
@@ -114,7 +114,7 @@
from __future__ import annotations
from typing import Optional
from cutlass.utils.lazy_import import lazy_import
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
from cutlass_library import (
@@ -127,15 +127,15 @@ from cutlass_library import (
StrideSupport,
)
import cutlass
from cutlass import epilogue
from cutlass.backend import compiler
from cutlass.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
from cutlass.backend.reduction_operation import ReductionOperation, ReductionArguments
from cutlass.backend.library import TensorDescription, TileDescription
from cutlass.op.op import OperationBase
from cutlass.shape import Conv2DProblemSize, MatrixCoord
from cutlass.utils import check, datatypes
import cutlass_cppgen
from cutlass_cppgen import epilogue
from cutlass_cppgen.backend import compiler
from cutlass_cppgen.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
from cutlass_cppgen.backend.reduction_operation import ReductionOperation, ReductionArguments
from cutlass_cppgen.backend.library import TensorDescription, TileDescription
from cutlass_cppgen.op.op import OperationBase
from cutlass_cppgen.shape import Conv2DProblemSize, MatrixCoord
from cutlass_cppgen.utils import check, datatypes
class Conv2d(OperationBase):
@@ -155,11 +155,11 @@ class Conv2d(OperationBase):
# Use F32 for A, B, C, D, and accumulation in fprop
# Use the generic ``element`` parameter to concisely set all data types for operands to the same values.
Conv2d(kind="fprop", element=cutlass.DataType.f32)
Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f32)
# Explicitly specify the data types to use for A, B, C, and D.
Conv2d(kind="fprop", element_A=cutlass.DataType.f32, element_B=cutlass.DataType.f32,
element_C=cutlass.DataType.f32, element_D=cutlass.DataType.f32)
Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32,
element_C=cutlass_cppgen.DataType.f32, element_D=cutlass_cppgen.DataType.f32)
# Set the data types and elements from existing tensors. Note that one can use different tensors when
# executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
@@ -169,8 +169,8 @@ class Conv2d(OperationBase):
# Explicitly specify the data type for only some of A, B, C, and D. Unspecified data types will inherit
# those passed in via the generic ``element``
Conv2d(kind="fprop", element_A=cutlass.DataType.f32, element_accumulator=cutlass.DataType.f32,
element=cutlass.DataType.f32)
Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32,
element=cutlass_cppgen.DataType.f32)
The order of precedence for the setting of the data type for a given operand/output is as follows:
1) If the tensor type is specified (e.g., ``A``), use the data type inferred from this tensor
@@ -186,17 +186,17 @@ class Conv2d(OperationBase):
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
:param beta: scalar parameter beta from GEMM operation that scales operand C
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
:type element: cutlass.DataType
:type element: cutlass_cppgen.DataType
:param element_A: data type to be used for operand A
:type element_A: cutlass.DataType
:type element_A: cutlass_cppgen.DataType
:param element_B: data type to be used for operand B
:type element_B: cutlass.DataType
:type element_B: cutlass_cppgen.DataType
:param element_C: data type to be used for operand C
:type element_C: cutlass.DataType
:type element_C: cutlass_cppgen.DataType
:param element_D: data type to be used for operand D
:type element_D: cutlass.DataType
:type element_D: cutlass_cppgen.DataType
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
:type element_accumulator: cutlass.DataType
:type element_accumulator: cutlass_cppgen.DataType
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
:type cc: int
:param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
@@ -215,7 +215,7 @@ class Conv2d(OperationBase):
if self.current_cc == 90:
# The Conv2d kernel on Hopper (SM90) is currently unsupported
# Revert to use SM80-tagged kernels
cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
self.specified_kernel_cc = 80
self._reset_options(80)
@@ -250,7 +250,7 @@ class Conv2d(OperationBase):
assert elt_to_set is not None
# Currently we only support layout TensorNHWC
lay_to_set = cutlass.LayoutType.TensorNHWC
lay_to_set = cutlass_cppgen.LayoutType.TensorNHWC
elements.append(datatypes.library_type(elt_to_set))
layouts.append(lay_to_set)
@@ -301,10 +301,10 @@ class Conv2d(OperationBase):
self._layout_a, self._layout_b, self._math_operation
)
if cutlass.OpcodeClass.TensorOp in self.possible_op_classes:
self.opclass = cutlass.OpcodeClass.TensorOp
elif cutlass.OpcodeClass.Simt in self.possible_op_classes:
self.opclass = cutlass.OpcodeClass.Simt
if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes:
self.opclass = cutlass_cppgen.OpcodeClass.TensorOp
elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes:
self.opclass = cutlass_cppgen.OpcodeClass.Simt
else:
if self._math_operation is not None:
math_op_str = f' and math operation {self._math_operation}'
@@ -342,7 +342,7 @@ class Conv2d(OperationBase):
Set the tile description
:param td: tile description
:type td: cutlass.backend.TileDescription, or a dict with keys
:type td: cutlass_cppgen.backend.TileDescription, or a dict with keys
{
"threadblock_shape": [int, int, int],
"warp_count": [int, int, int],
@@ -359,7 +359,7 @@ class Conv2d(OperationBase):
self._tile_description = datatypes.td_from_profiler_op(op)
if "cluster_shape" in td.keys():
if td["cluster_shape"] != [1, 1, 1]:
cutlass.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.")
cutlass_cppgen.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.")
td["cluster_shape"] = [1, 1, 1]
td = self._tile_description.clone_and_update(td)
@@ -381,7 +381,7 @@ class Conv2d(OperationBase):
- Is the kernel schedule being used supported on the architecture in question?
:param td: tile description to validate
:type td: cutlass.backend.TileDescription
:type td: cutlass_cppgen.backend.TileDescription
:return: tuple in which the first element is a bool indicating that the tile description is valid
and the second element is a string providing an optional error message.
:rtype: tuple
@@ -445,9 +445,9 @@ class Conv2d(OperationBase):
"""
if self.conv_kind == ConvKind.Dgrad:
if stride[0] != 1 or stride[1] != 1:
return getattr(cutlass.swizzle, f"StridedDgradIdentitySwizzle{self._swizzling_stride}")
return getattr(cutlass_cppgen.swizzle, f"StridedDgradIdentitySwizzle{self._swizzling_stride}")
return getattr(cutlass.swizzle, f"IdentitySwizzle{self._swizzling_stride}")
return getattr(cutlass_cppgen.swizzle, f"IdentitySwizzle{self._swizzling_stride}")
#
# Iterator Algorithm Related
@@ -546,14 +546,14 @@ class Conv2d(OperationBase):
self, tile_description: TileDescription = None,
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
iterator_algorithm: IteratorAlgorithm = None,
stride_support = None, swizzling_functor: cutlass.swizzle = None,
epilogue_functor=None) -> cutlass.backend.Conv2dOperation:
stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None,
epilogue_functor=None) -> cutlass_cppgen.backend.Conv2dOperation:
"""
Constructs a ``cutlass.backend.Conv2dOperation`` based on the input parameters and current
Constructs a ``cutlass_cppgen.backend.Conv2dOperation`` based on the input parameters and current
kernel specification of the ``Conv2d`` object.
:param tile_description: tile description specifying shapes and operand types to use in the kernel
:type tile_description: cutlass.backend.TileDescription
:type tile_description: cutlass_cppgen.backend.TileDescription
:param alignment_A: alignment of operand A
:type alignment_A: int
:param alignment_B: alignment of operand B
@@ -565,11 +565,11 @@ class Conv2d(OperationBase):
:param stride_support: the stride support of dgrad
:type stride_support: cutlass_library.library.StrideSupport
:param swizzling_functor: the swizzling functor
:type swizzling_functor: cutlass.swizzle
:type swizzling_functor: cutlass_cppgen.swizzle
:param epilogue_functor: the epilogue functor
:return: operation that was constructed
:rtype: cutlass.backend.Conv2dOperation
:rtype: cutlass_cppgen.backend.Conv2dOperation
"""
# Get alignment
alignment_A = check.alignment_or_default(alignment_A, self.alignment_pref_A)
@@ -637,8 +637,8 @@ class Conv2d(OperationBase):
def compile(self, tile_description: TileDescription = None,
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
iterator_algorithm: IteratorAlgorithm = None,
stride_support = None, swizzling_functor: cutlass.swizzle = None,
epilogue_functor = None, print_module: bool = False) -> cutlass.backend.Conv2dOperation:
stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None,
epilogue_functor = None, print_module: bool = False) -> cutlass_cppgen.backend.Conv2dOperation:
"""
Emits and compiles the kernel currently specified. If ``tile_description`` and any
of the ``alignment`` parameters are set, the kernel will be chosen using this
@@ -646,7 +646,7 @@ class Conv2d(OperationBase):
will be used.
::param tile_description: tile description specifying shapes and operand types to use in the kernel
:type tile_description: cutlass.backend.TileDescription
:type tile_description: cutlass_cppgen.backend.TileDescription
:param alignment_A: alignment of operand A
:type alignment_A: int
:param alignment_B: alignment of operand B
@@ -658,11 +658,11 @@ class Conv2d(OperationBase):
:param stride_support: the stride support of dgrad
:type stride_support: cutlass_library.library.StrideSupport
:param swizzling_functor: the swizzling functor
:type swizzling_functor: cutlass.swizzle
:type swizzling_functor: cutlass_cppgen.swizzle
:param epilogue_functor: the epilogue functor
:return: operation that was compiled
:rtype: cutlass.backend.Conv2dOperation
:rtype: cutlass_cppgen.backend.Conv2dOperation
"""
self.operation = self.construct(
@@ -770,7 +770,7 @@ class Conv2d(OperationBase):
:type stream: :class:`cuda.cuda.CUstream`
:return: arguments passed in to the kernel
:rtype: cutlass.backend.Conv2dArguments
:rtype: cutlass_cppgen.backend.Conv2dArguments
"""
if not stream:
stream = cuda.CUstream(0)

View File

@@ -47,7 +47,7 @@
.. code-block:: python
# A, B, C, and D are torch/numpy/cupy tensor objects
plan = cutlass.op.Gemm(A, B, C, D)
plan = cutlass_cppgen.op.Gemm(A, B, C, D)
plan.run()
@@ -58,11 +58,11 @@
.. code-block:: python
# The following is shorthand for:
# cutlass.op.Gemm(element_A=torch.float32, element_B=torch.float32,
# cutlass_cppgen.op.Gemm(element_A=torch.float32, element_B=torch.float32,
# element_C=torch.float32, element_D=torch.float32,
# element_accumulator=torch.float32,
# layout=cutlass.LayoutType.RowMajor)
plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)
# layout=cutlass_cppgen.LayoutType.RowMajor)
plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
A0 = torch.rand((128, 256), device='cuda')
B0 = torch.rand((256, 64), device='cuda')
@@ -82,7 +82,7 @@
.. highlight:: python
.. code-block:: python
plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)
plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
plan.compile()
# Do other work...
@@ -98,15 +98,15 @@
.. highlight:: python
.. code-block:: python
plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)
plan.activation = cutlass.epilogue.relu
plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
plan.activation = cutlass_cppgen.epilogue.relu
Operations can also be run asynchronously:
.. highlight:: python
.. code-block:: python
plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)
plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
args = plan.run()
# Do other work...
@@ -117,7 +117,7 @@ from __future__ import annotations
from typing import Optional
from math import prod
from cutlass.utils.lazy_import import lazy_import
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
from cutlass_library import (
DataType,
@@ -125,15 +125,15 @@ from cutlass_library import (
GemmUniversalMode,
)
import cutlass
from cutlass import epilogue, swizzle
from cutlass.backend import compiler
from cutlass.backend.evt import EpilogueFunctorVisitor
from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal
from cutlass.backend.library import TensorDescription, TileDescription
from cutlass.op.op import OperationBase
from cutlass.shape import GemmCoord
from cutlass.utils import check, datatypes
import cutlass_cppgen
from cutlass_cppgen import epilogue, swizzle
from cutlass_cppgen.backend import compiler
from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal
from cutlass_cppgen.backend.library import TensorDescription, TileDescription
from cutlass_cppgen.op.op import OperationBase
from cutlass_cppgen.shape import GemmCoord
from cutlass_cppgen.utils import check, datatypes
class Gemm(OperationBase):
@@ -154,11 +154,11 @@ class Gemm(OperationBase):
# Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts
# for operands to the same values.
Gemm(element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)
Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
# Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``.
Gemm(element_A=cutlass.DataType.f32, element_B=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
element_D=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)
Gemm(element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32, element_C=cutlass_cppgen.DataType.f32,
element_D=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
# Set the data types and elements from existing tensors. Note that one can use different tensors when
# executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
@@ -168,13 +168,13 @@ class Gemm(OperationBase):
# Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (layout of D is
# the same as that for D, at present)
Gemm(element=cutlass.DataType.f32, layout_A=cutlass.LayoutType.RowMajor,
layout_B=cutlass.LayoutType.RowMajor, layout_C=cutlass.LayoutType.RowMajor)
Gemm(element=cutlass_cppgen.DataType.f32, layout_A=cutlass_cppgen.LayoutType.RowMajor,
layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor)
# Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types
# and layouts will inherit those passed in via the generic ``element`` and ``layout``
Gemm(element_A=cutlass.DataType.f32, layout_B=cutlass.LayoutType.RowMajor,
element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)
Gemm(element_A=cutlass_cppgen.DataType.f32, layout_B=cutlass_cppgen.LayoutType.RowMajor,
element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
The order of precedence for the setting of the data type and layout for a given operand/output is as follows:
1) If the tensor type is specified (e.g., ``A``), use the data type and layout inferred from this tensor
@@ -192,27 +192,27 @@ class Gemm(OperationBase):
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
:param beta: scalar parameter beta from GEMM operation that scales operand C
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
:type element_accumulator: cutlass.DataType
:type element_accumulator: cutlass_cppgen.DataType
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
:type element: cutlass.DataType
:type element: cutlass_cppgen.DataType
:param layout: generic layout type to be used for operands A, B, C, and D
:type layout: cutlass.LayoutType
:type layout: cutlass_cppgen.LayoutType
:param element_A: data type to be used for operand A
:type element_A: cutlass.DataType
:type element_A: cutlass_cppgen.DataType
:param element_B: data type to be used for operand B
:type element_B: cutlass.DataType
:type element_B: cutlass_cppgen.DataType
:param element_C: data type to be used for operand C
:type element_C: cutlass.DataType
:type element_C: cutlass_cppgen.DataType
:param element_D: data type to be used for operand D
:type element_D: cutlass.DataType
:type element_D: cutlass_cppgen.DataType
:param layout_A: layout of operand A
:type layout_A: cutlass.LayoutType
:type layout_A: cutlass_cppgen.LayoutType
:param layout_B: layout of operand B
:type layout_B: cutlass.LayoutType
:type layout_B: cutlass_cppgen.LayoutType
:param layout_C: layout of operand C
:type layout_C: cutlass.LayoutType
:type layout_C: cutlass_cppgen.LayoutType
:param layout_D: layout of operand D
:type layout_D: cutlass.LayoutType
:type layout_D: cutlass_cppgen.LayoutType
"""
def __init__(
@@ -278,7 +278,7 @@ class Gemm(OperationBase):
self._reset_operations()
self._swizzling_functor = cutlass.swizzle.IdentitySwizzle1
self._swizzling_functor = cutlass_cppgen.swizzle.IdentitySwizzle1
def _reset_operations(self, reset_epilogue: bool = True):
# Set the default op class
@@ -289,10 +289,10 @@ class Gemm(OperationBase):
self._element_a, self._element_b, self._element_accumulator,
self._layout_a, self._layout_b, self._math_operation)
if cutlass.OpcodeClass.TensorOp in self.possible_op_classes:
self.opclass = cutlass.OpcodeClass.TensorOp
elif cutlass.OpcodeClass.Simt in self.possible_op_classes:
self.opclass = cutlass.OpcodeClass.Simt
if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes:
self.opclass = cutlass_cppgen.OpcodeClass.TensorOp
elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes:
self.opclass = cutlass_cppgen.OpcodeClass.Simt
else:
if self._math_operation is not None:
math_op_str = f' and math operation {self._math_operation}'
@@ -303,7 +303,7 @@ class Gemm(OperationBase):
f'combination {datatype_comb}x{layout_comb}{math_op_str}')
if reset_epilogue:
self._reset_epilogue_functor_activation(cutlass.epilogue.identity)
self._reset_epilogue_functor_activation(cutlass_cppgen.epilogue.identity)
@property
def swizzling_functor(self):
@@ -319,8 +319,8 @@ class Gemm(OperationBase):
"""
Sets the swizzling functor to the type specified by `swizzling_functor`
"""
if swizzling_functor == cutlass.swizzle.ThreadblockSwizzleStreamK:
if self.op_class == cutlass.OpcodeClass.Simt:
if swizzling_functor == cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK:
if self.op_class == cutlass_cppgen.OpcodeClass.Simt:
raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp')
if self.current_cc == 90:
@@ -345,7 +345,7 @@ class Gemm(OperationBase):
Set the tile description
:param td: tile description
:type td: cutlass.backend.TileDescription, or a dict with keys
:type td: cutlass_cppgen.backend.TileDescription, or a dict with keys
{
"threadblock_shape": [int, int, int],
"warp_count": [int, int, int],
@@ -380,7 +380,7 @@ class Gemm(OperationBase):
- Is the kernel schedule being used supported on the architecture in question?
:param td: tile description to validate
:type td: cutlass.backend.TileDescription
:type td: cutlass_cppgen.backend.TileDescription
:return: tuple in which the first element is a bool indicating that the tile description is valid
and the second element is a string providing an optional error message.
:rtype: tuple
@@ -412,11 +412,11 @@ class Gemm(OperationBase):
self, tile_description: TileDescription = None,
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None) -> GemmOperationUniversal:
"""
Constructs a ``cutlass.backend.GemmUniversalOperation`` based on the input parameters and current
Constructs a ``cutlass_cppgen.backend.GemmUniversalOperation`` based on the input parameters and current
kernel specification of the ``Gemm`` object.
:param tile_description: tile description specifying shapes and operand types to use in the kernel
:type tile_description: cutlass.backend.TileDescription
:type tile_description: cutlass_cppgen.backend.TileDescription
:param alignment_A: alignment of operand A
:type alignment_A: int
:param alignment_B: alignment of operand B
@@ -425,7 +425,7 @@ class Gemm(OperationBase):
:type alignment_C: int
:return: operation that was constructed
:rtype: cutlass.backend.GemmOperationUniversal
:rtype: cutlass_cppgen.backend.GemmOperationUniversal
"""
alignment_pref_A = min(128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A")))
alignment_pref_B = min(128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B")))
@@ -471,7 +471,7 @@ class Gemm(OperationBase):
def compile(self, tile_description: TileDescription = None,
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
print_module: bool = False) -> cutlass.backend.GemmOperationUniversal:
print_module: bool = False) -> cutlass_cppgen.backend.GemmOperationUniversal:
"""
Emits and compiles the kernel currently specified. If ``tile_description`` and any
of the ``alignment`` parameters are set, the kernel will be chosen using this
@@ -479,7 +479,7 @@ class Gemm(OperationBase):
will be used.
:param tile_description: tile description specifying shapes and operand types to use in the kernel
:type tile_description: cutlass.backend.TileDescription
:type tile_description: cutlass_cppgen.backend.TileDescription
:param alignment_A: alignment of operand A
:type alignment_A: int
:param alignment_B: alignment of operand B
@@ -490,7 +490,7 @@ class Gemm(OperationBase):
:type print_module: bool
:return: operation that was compiled
:rtype: cutlass.backend.GemmOperationUniversal
:rtype: cutlass_cppgen.backend.GemmOperationUniversal
"""
self.operation = self.construct(tile_description, alignment_A, alignment_B, alignment_C)
@@ -566,7 +566,7 @@ class Gemm(OperationBase):
:param D: tensor D
:type D: numpy/cupy/torch array/tensor object
:return: tuple containing the problem size (cutlass.shape.GemmCoord), the GEMM mode (cutlass.GemmUniversalMode), and the batch count (int)
:return: tuple containing the problem size (cutlass_cppgen.shape.GemmCoord), the GEMM mode (cutlass_cppgen.GemmUniversalMode), and the batch count (int)
:rtype: tuple
"""
M, K = A.shape[-2:]
@@ -582,9 +582,9 @@ class Gemm(OperationBase):
# and C are row major. A similar operation can be performed if only B has a nonzero
# batch dimension
if batch_count > 1:
A_row = self._layout_a == cutlass.LayoutType.RowMajor
B_row = self._layout_b == cutlass.LayoutType.RowMajor
C_row = self._layout_c == cutlass.LayoutType.RowMajor
A_row = self._layout_a == cutlass_cppgen.LayoutType.RowMajor
B_row = self._layout_b == cutlass_cppgen.LayoutType.RowMajor
C_row = self._layout_c == cutlass_cppgen.LayoutType.RowMajor
# Consider a Tensor to be batched if its rank is > 2 and
# the product of the modes beyond rank 2 equals our pre-determined batch size.
@@ -652,7 +652,7 @@ class Gemm(OperationBase):
:type stream: :class:`cuda.cuda.CUstream`
:return: arguments passed in to the kernel
:rtype: cutlass.backend.GemmArguments
:rtype: cutlass_cppgen.backend.GemmArguments
"""
if not stream:
stream = cuda.CUstream(0)

View File

@@ -47,27 +47,27 @@
.. code-block:: python
# As, Bs, Cs, and Ds are torch/numpy/cupy tensor objects
plan = cutlass.op.GroupedGemm(element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor)
plan = cutlass_cppgen.op.GroupedGemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor)
plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1])
"""
from __future__ import annotations
from typing import Optional
from cutlass_library import DataTypeSize
from cutlass.utils.lazy_import import lazy_import
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
from cutlass.backend.gemm_operation import (
from cutlass_cppgen.backend.gemm_operation import (
GemmGroupedArguments,
GemmOperationGrouped,
)
from cutlass.backend.library import (
from cutlass_cppgen.backend.library import (
SchedulerMode,
TensorDescription,
TileDescription,
)
from cutlass.op.gemm import Gemm
from cutlass.shape import GemmCoord
from cutlass.utils import check, datatypes
from cutlass_cppgen.op.gemm import Gemm
from cutlass_cppgen.shape import GemmCoord
from cutlass_cppgen.utils import check, datatypes
class GroupedGemm(Gemm):
@@ -90,27 +90,27 @@ class GroupedGemm(Gemm):
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
:param beta: scalar parameter beta from GEMM operation that scales operand C
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
:type element_accumulator: cutlass.DataType
:type element_accumulator: cutlass_cppgen.DataType
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
:type element: cutlass.DataType
:type element: cutlass_cppgen.DataType
:param layout: generic layout type to be used for operands A, B, C, and D
:type layout: cutlass.LayoutType
:type layout: cutlass_cppgen.LayoutType
:param element_A: data type to be used for operand A
:type element_A: cutlass.DataType
:type element_A: cutlass_cppgen.DataType
:param element_B: data type to be used for operand B
:type element_B: cutlass.DataType
:type element_B: cutlass_cppgen.DataType
:param element_C: data type to be used for operand C
:type element_C: cutlass.DataType
:type element_C: cutlass_cppgen.DataType
:param element_D: data type to be used for operand D
:type element_D: cutlass.DataType
:type element_D: cutlass_cppgen.DataType
:type layout_A: layout of operand A
:param layout_A: cutlass.LayoutType
:param layout_A: cutlass_cppgen.LayoutType
:type layout_B: layout of operand B
:param layout_B: cutlass.LayoutType
:param layout_B: cutlass_cppgen.LayoutType
:type layout_C: layout of operand C
:param layout_C: cutlass.LayoutType
:param layout_C: cutlass_cppgen.LayoutType
:type layout_D: layout of operand D
:param layout_D: cutlass.LayoutType
:param layout_D: cutlass_cppgen.LayoutType
"""
def __init__(
@@ -151,11 +151,11 @@ class GroupedGemm(Gemm):
alignment_B: int = None,
alignment_C: int = None) -> GemmOperationGrouped:
"""
Constructs a ``cutlass.backend.GemmOperationGrouped`` based on the input parameters and current
Constructs a ``cutlass_cppgen.backend.GemmOperationGrouped`` based on the input parameters and current
kernel specification of the ``Gemm`` object.
:param tile_description: tile description specifying shapes and operand types to use in the kernel
:type tile_description: cutlass.backend.TileDescription
:type tile_description: cutlass_cppgen.backend.TileDescription
:param alignment_A: alignment of operand A
:type alignment_A: int
:param alignment_B: alignment of operand B
@@ -164,7 +164,7 @@ class GroupedGemm(Gemm):
:type alignment_C: int
:return: operation that was constructed
:rtype: cutlass.backend.GemmOperationGrouped
:rtype: cutlass_cppgen.backend.GemmOperationGrouped
"""
alignment_A = check.alignment_or_default(alignment_A, max(self.possible_operations.alignments("A")))
alignment_B = check.alignment_or_default(alignment_B, max(self.possible_operations.alignments("B")))
@@ -225,7 +225,7 @@ class GroupedGemm(Gemm):
:type stream: :class:`cuda.cuda.CUstream`
:return: arguments passed in to the kernel
:rtype: cutlass.backend.GemmGroupedArguments
:rtype: cutlass_cppgen.backend.GemmGroupedArguments
"""
if not stream:
stream = cuda.CUstream(0)

View File

@@ -44,14 +44,14 @@ from cutlass_library import (
SharedMemPerCC
)
import cutlass
from cutlass import get_option_registry
from cutlass.backend.evt import EpilogueFunctorVisitor
from cutlass.backend.utils.device import device_cc
from cutlass.epilogue import get_activations, get_activation_epilogue, identity
from cutlass.library_defaults import KernelsForDataType, _generator_ccs
from cutlass.swizzle import get_swizzling_functors
from cutlass.utils import datatypes, check
import cutlass_cppgen
from cutlass_cppgen import get_option_registry
from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
from cutlass_cppgen.backend.utils.device import device_cc
from cutlass_cppgen.epilogue import get_activations, get_activation_epilogue, identity
from cutlass_cppgen.library_defaults import KernelsForDataType, _generator_ccs
from cutlass_cppgen.swizzle import get_swizzling_functors
from cutlass_cppgen.utils import datatypes, check
class OperationBase:
@@ -205,19 +205,19 @@ class OperationBase:
return tensor
@property
def opclass(self) -> cutlass.OpcodeClass:
def opclass(self) -> cutlass_cppgen.OpcodeClass:
"""
Returns the opcode class currently in use
:return: opcode class currently in use
:rtype: cutlass.OpcodeClass
:rtype: cutlass_cppgen.OpcodeClass
"""
return self.op_class
@opclass.setter
def opclass(self, oc: cutlass.OpcodeClass):
def opclass(self, oc: cutlass_cppgen.OpcodeClass):
if isinstance(oc, str):
oc = datatypes.getattr_enum(cutlass.OpcodeClass, oc)
oc = datatypes.getattr_enum(cutlass_cppgen.OpcodeClass, oc)
if oc in self.possible_op_classes:
self.op_class = oc
else:
@@ -236,25 +236,25 @@ class OperationBase:
self.epilogue_functor = self._reset_epilogue_functor_alignment(self._elements_per_access(), self.epilogue_functor)
@property
def math_operation(self) -> cutlass.MathOperation:
def math_operation(self) -> cutlass_cppgen.MathOperation:
"""
Returns the math operation currently in use
:return: math operation currently in use
:rtype: cutlass.MathOperation
:rtype: cutlass_cppgen.MathOperation
"""
return self._math_operation
@math_operation.setter
def math_operation(self, mo: cutlass.MathOperation):
def math_operation(self, mo: cutlass_cppgen.MathOperation):
if isinstance(mo, str):
mo = datatypes.getattr_enum(cutlass.MathOperation, mo)
mo = datatypes.getattr_enum(cutlass_cppgen.MathOperation, mo)
if not self.specified_kernel_cc:
if self.current_cc == 90:
# CUTLASS 3.0 kernels do not use different math operations. If one is specified, we
# revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
self._reset_options(80)
self._reset_operations(reset_epilogue=False)
elif self.current_cc == 90:
@@ -266,7 +266,7 @@ class OperationBase:
self._reset_operations()
def _elements_per_access(self):
if self.op_class == cutlass.OpcodeClass.Simt:
if self.op_class == cutlass_cppgen.OpcodeClass.Simt:
return 1
elif self._element_c != DataType.void:
return 128 // DataTypeSize[self._element_c]
@@ -286,7 +286,7 @@ class OperationBase:
if self.current_cc == 90 and activation != identity:
# CUTLASS 3.0 kernels in Python currently only support identity activation. If one requests a non-identity activation,
# revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
if self._element_c != self._element_d:
raise Exception("CUTLASS 2.x kernels require element C to be the same as element D")
self._reset_options(80)
@@ -361,7 +361,7 @@ class OperationBase:
"""
if isinstance(act, tuple):
if isinstance(act[0], str):
act_fn = getattr(cutlass.backend.epilogue, act[0])
act_fn = getattr(cutlass_cppgen.backend.epilogue, act[0])
else:
act_fn = act[0]
self._reset_epilogue_functor_activation(act_fn)
@@ -369,7 +369,7 @@ class OperationBase:
self._activation = act[0]
else:
if isinstance(act, str):
act = getattr(cutlass.backend.epilogue, act)
act = getattr(cutlass_cppgen.backend.epilogue, act)
self._reset_epilogue_functor_activation(act)
self._activation = act
@@ -401,8 +401,8 @@ class OperationBase:
td = datatypes.td_from_profiler_op(operation)
# Filter invalid epilogue schedules
if td.epilogue_schedule not in [
cutlass.EpilogueScheduleType.TmaWarpSpecialized,
cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative]:
cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized,
cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative]:
continue
epilogue_smem_bytes = self.epilogue_functor.get_smem_size(td)
@@ -427,4 +427,4 @@ class OperationBase:
Steps that must be taken before caling `plan.run()`
"""
# Initialize the memory pool if, if not already done
cutlass.get_memory_pool()
cutlass_cppgen.get_memory_pool()

View File

@@ -39,7 +39,7 @@ from cutlass_library import (
ConvKind,
LayoutType
)
from cutlass.backend.c_types import (
from cutlass_cppgen.backend.c_types import (
Conv2DProblemSize_,
GemmCoord_,
GemmCoordBatched_

View File

@@ -30,7 +30,7 @@
#
#################################################################################################
from cutlass.utils.check import (
from cutlass_cppgen.utils.check import (
alignment_or_default,
calculate_smem_usage,
calculate_smem_usage_per_stage,

View File

@@ -38,8 +38,8 @@ import ctypes
from cutlass_library import DataTypeSize, OperationKind, SharedMemPerCC
import cutlass
from cutlass.backend.library import TileDescription
import cutlass_cppgen
from cutlass_cppgen.backend.library import TileDescription
def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: OperationKind) -> int:
@@ -82,8 +82,8 @@ def valid_stage_count(
cc: int,
kernel_cc: int,
td: TileDescription,
element_C: cutlass.DataType = None,
element_D: cutlass.DataType = None,
element_C: cutlass_cppgen.DataType = None,
element_D: cutlass_cppgen.DataType = None,
verbose: bool = True) -> tuple:
"""
Checks whether a device with `cc` supports the number of stages within `tile_description`, both
@@ -96,9 +96,9 @@ def valid_stage_count(
:param td: tile description to check
:type td: TileDescription
:param element_C: data type of operand C
:type element_C: cutlass.DataType
:type element_C: cutlass_cppgen.DataType
:param element_D: data type of operand D
:type element_D: cutlass.DataType
:type element_D: cutlass_cppgen.DataType
:param verbose: whether to log warnings
:type verbose: bool
@@ -112,7 +112,7 @@ def valid_stage_count(
# determines the stage count to use. Thus, all settings are valid in these scenarios.
return (True, "")
elif verbose:
cutlass.logger.warning(
cutlass_cppgen.logger.warning(
"Setting an explicit stage count for SM90 kernels currently may "
"result in compilation errors if the combination of tile shape, "
"stage count, and shared memory requirement of the epilogue exceeds "
@@ -188,9 +188,9 @@ def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple:
def valid_schedule(
cc: int,
kernel_schedule: cutlass.KernelScheduleType,
epilogue_schedule: cutlass.EpilogueScheduleType,
tile_scheduler: cutlass.TileSchedulerType) -> tuple:
kernel_schedule: cutlass_cppgen.KernelScheduleType,
epilogue_schedule: cutlass_cppgen.EpilogueScheduleType,
tile_scheduler: cutlass_cppgen.TileSchedulerType) -> tuple:
"""
Checks that the kernel and epilogue schedules passed in are a valid combination for
a device of compute capability ``cc``.
@@ -198,19 +198,19 @@ def valid_schedule(
:param cc: compute capability of device in question
:type cc: int
:param kernel_schedule: kernel schedule type
:type kernel_schedule: cutlass.KernelScheduleType
:type kernel_schedule: cutlass_cppgen.KernelScheduleType
:param epilogue_schedule: epilogue schedule type
:type epilogue_schedule: cutlass.EpilogueScheduleType
:type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType
:param tile_scheduler: tile scheduler type
:type tile_scheduler: cutlass.TileSchedulerType
:type tile_scheduler: cutlass_cppgen.TileSchedulerType
:return: tuple with the first element indicating whether the provided schedules are
valid for the provided device and the second element being an error message
:rtype: tuple
"""
kernel_auto = (kernel_schedule == cutlass.KernelScheduleType.ScheduleAuto)
epilogue_auto = (epilogue_schedule == cutlass.EpilogueScheduleType.ScheduleAuto)
tile_scheduler_default = (tile_scheduler == cutlass.TileSchedulerType.Default)
kernel_auto = (kernel_schedule == cutlass_cppgen.KernelScheduleType.ScheduleAuto)
epilogue_auto = (epilogue_schedule == cutlass_cppgen.EpilogueScheduleType.ScheduleAuto)
tile_scheduler_default = (tile_scheduler == cutlass_cppgen.TileSchedulerType.Default)
if cc < 90 and not (kernel_auto and epilogue_auto and tile_scheduler_default):
return (False, "Non-default schedules are only supported on SM90 and beyond")
@@ -218,9 +218,9 @@ def valid_schedule(
return (False, "Kernel and epilogue schedules must either both be auto or neither be auto")
if not tile_scheduler_default:
cooperative_kernels = [cutlass.KernelScheduleType.TmaWarpSpecializedCooperative,
cutlass.KernelScheduleType.CpAsyncWarpSpecializedCooperative]
if (tile_scheduler == cutlass.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels):
cooperative_kernels = [cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative,
cutlass_cppgen.KernelScheduleType.CpAsyncWarpSpecializedCooperative]
if (tile_scheduler == cutlass_cppgen.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels):
return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule")
return (True, "")

View File

@@ -34,13 +34,13 @@
Utility functions for converting between frontend datatypes and CUTLASS datatypes
"""
import cutlass
import cutlass_cppgen
from cutlass_library import (
DataTypeSize,
MathOperation,
MathInstruction
)
from cutlass.backend.library import (
from cutlass_cppgen.backend.library import (
TileDescription,
)
@@ -62,11 +62,11 @@ def is_numpy_available():
numpy_available = True
_library_to_numpy_dict = {
cutlass.DataType.f16: np.float16,
cutlass.DataType.f32: np.float32,
cutlass.DataType.f64: np.float64,
cutlass.DataType.s8: np.int8,
cutlass.DataType.s32: np.int32,
cutlass_cppgen.DataType.f16: np.float16,
cutlass_cppgen.DataType.f32: np.float32,
cutlass_cppgen.DataType.f64: np.float64,
cutlass_cppgen.DataType.s8: np.int8,
cutlass_cppgen.DataType.s32: np.int32,
}
except ImportError:
numpy_available = False
@@ -81,19 +81,19 @@ def is_numpy_tensor(inp) -> bool:
return False
def numpy_library_type(inp) -> cutlass.DataType:
def numpy_library_type(inp) -> cutlass_cppgen.DataType:
if is_numpy_available():
import numpy as np
if inp == np.float16:
return cutlass.DataType.f16
return cutlass_cppgen.DataType.f16
elif inp == np.float32:
return cutlass.DataType.f32
return cutlass_cppgen.DataType.f32
elif inp == np.float64:
return cutlass.DataType.f64
return cutlass_cppgen.DataType.f64
elif inp == np.int8:
return cutlass.DataType.s8
return cutlass_cppgen.DataType.s8
elif inp == np.int32:
return cutlass.DataType.s32
return cutlass_cppgen.DataType.s32
return None
@@ -109,11 +109,11 @@ def is_cupy_available():
cupy_available = True
_library_to_cupy_dict = {
cutlass.DataType.f16: cp.float16,
cutlass.DataType.f32: cp.float32,
cutlass.DataType.f64: cp.float64,
cutlass.DataType.s8: cp.int8,
cutlass.DataType.s32: cp.int32,
cutlass_cppgen.DataType.f16: cp.float16,
cutlass_cppgen.DataType.f32: cp.float32,
cutlass_cppgen.DataType.f64: cp.float64,
cutlass_cppgen.DataType.s8: cp.int8,
cutlass_cppgen.DataType.s32: cp.int32,
}
except ImportError:
cupy_available = False
@@ -128,15 +128,15 @@ def is_cupy_tensor(inp) -> bool:
return False
def cupy_library_type(inp) -> cutlass.DataType:
def cupy_library_type(inp) -> cutlass_cppgen.DataType:
if is_cupy_available():
import cupy as cp
if inp == cp.float16:
return cutlass.DataType.f16
return cutlass_cppgen.DataType.f16
elif inp == cp.float32:
return cutlass.DataType.f32
return cutlass_cppgen.DataType.f32
elif inp == cp.float64:
return cutlass.DataType.f64
return cutlass_cppgen.DataType.f64
return None
@@ -152,29 +152,29 @@ def is_torch_available():
torch_available = True
_torch_to_library_dict = {
torch.half: cutlass.DataType.f16,
torch.float16: cutlass.DataType.f16,
torch.bfloat16: cutlass.DataType.bf16,
torch.float: cutlass.DataType.f32,
torch.float32: cutlass.DataType.f32,
torch.double: cutlass.DataType.f64,
torch.float64: cutlass.DataType.f64,
torch.int8: cutlass.DataType.s8,
torch.int32: cutlass.DataType.s32,
torch.uint8: cutlass.DataType.u8,
torch.half: cutlass_cppgen.DataType.f16,
torch.float16: cutlass_cppgen.DataType.f16,
torch.bfloat16: cutlass_cppgen.DataType.bf16,
torch.float: cutlass_cppgen.DataType.f32,
torch.float32: cutlass_cppgen.DataType.f32,
torch.double: cutlass_cppgen.DataType.f64,
torch.float64: cutlass_cppgen.DataType.f64,
torch.int8: cutlass_cppgen.DataType.s8,
torch.int32: cutlass_cppgen.DataType.s32,
torch.uint8: cutlass_cppgen.DataType.u8,
}
_library_to_torch_dict = {
cutlass.DataType.f16: torch.half,
cutlass.DataType.f16: torch.float16,
cutlass.DataType.bf16: torch.bfloat16,
cutlass.DataType.f32: torch.float,
cutlass.DataType.f32: torch.float32,
cutlass.DataType.f64: torch.double,
cutlass.DataType.f64: torch.float64,
cutlass.DataType.s8: torch.int8,
cutlass.DataType.s32: torch.int32,
cutlass.DataType.u8: torch.uint8,
cutlass_cppgen.DataType.f16: torch.half,
cutlass_cppgen.DataType.f16: torch.float16,
cutlass_cppgen.DataType.bf16: torch.bfloat16,
cutlass_cppgen.DataType.f32: torch.float,
cutlass_cppgen.DataType.f32: torch.float32,
cutlass_cppgen.DataType.f64: torch.double,
cutlass_cppgen.DataType.f64: torch.float64,
cutlass_cppgen.DataType.s8: torch.int8,
cutlass_cppgen.DataType.s32: torch.int32,
cutlass_cppgen.DataType.u8: torch.uint8,
}
def possibly_add_type(torch_type_name, cutlass_type):
@@ -184,8 +184,8 @@ def is_torch_available():
_torch_to_library_dict[torch_type] = cutlass_type
_library_to_torch_dict[cutlass_type] = torch_type
possibly_add_type("float8_e4m3fn", cutlass.DataType.e4m3)
possibly_add_type("float8_e5m2", cutlass.DataType.e5m2)
possibly_add_type("float8_e4m3fn", cutlass_cppgen.DataType.e4m3)
possibly_add_type("float8_e5m2", cutlass_cppgen.DataType.e5m2)
except ImportError:
torch_available = False
@@ -201,7 +201,7 @@ def is_torch_tensor(inp) -> bool:
return False
def torch_library_type(inp) -> cutlass.DataType:
def torch_library_type(inp) -> cutlass_cppgen.DataType:
return _torch_to_library_dict.get(inp, None)
@@ -222,17 +222,17 @@ def is_bfloat16_available():
return bfloat16_available
def bfloat16_library_type(inp) -> cutlass.DataType:
def bfloat16_library_type(inp) -> cutlass_cppgen.DataType:
if is_bfloat16_available():
import bfloat16
if inp == bfloat16.bfloat16:
return cutlass.DataType.bf16
return cutlass_cppgen.DataType.bf16
def bfloat16_type(inp):
if is_bfloat16_available():
import bfloat16
if inp == cutlass.DataType.bf16:
if inp == cutlass_cppgen.DataType.bf16:
return bfloat16.bfloat16
@@ -256,15 +256,15 @@ def library_type(inp):
def _tensor_from_numpy(np_tensor):
dtype = library_type(np_tensor.dtype)
if np_tensor.flags.c_contiguous:
layout = cutlass.LayoutType.RowMajor
layout = cutlass_cppgen.LayoutType.RowMajor
elif np_tensor.flags.f_contiguous:
layout = cutlass.LayoutType.ColumnMajor
layout = cutlass_cppgen.LayoutType.ColumnMajor
return (dtype, layout)
def _tensor_from_torch(pt_tensor):
dtype = library_type(pt_tensor.dtype)
return (dtype, cutlass.LayoutType.RowMajor)
return (dtype, cutlass_cppgen.LayoutType.RowMajor)
def get_datatype_and_layout(tensor):
@@ -273,7 +273,7 @@ def get_datatype_and_layout(tensor):
elif is_torch_tensor(tensor):
return _tensor_from_torch(tensor)
elif isinstance(tensor, float) or isinstance(tensor, int):
return (cutlass.DataType.f32, cutlass.LayoutType.RowMajor)
return (cutlass_cppgen.DataType.f32, cutlass_cppgen.LayoutType.RowMajor)
else:
raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
@@ -303,10 +303,10 @@ def backend_math_operation(math_op: MathOperation):
return _math_operation_value_map[math_op.value]
def construct_backend_td(td: cutlass.TileDescription,
kernel_schedule: cutlass.KernelScheduleType,
epilogue_schedule: cutlass.EpilogueScheduleType,
tile_scheduler: cutlass.TileSchedulerType) -> TileDescription:
def construct_backend_td(td: cutlass_cppgen.TileDescription,
kernel_schedule: cutlass_cppgen.KernelScheduleType,
epilogue_schedule: cutlass_cppgen.EpilogueScheduleType,
tile_scheduler: cutlass_cppgen.TileSchedulerType) -> TileDescription:
mi = td.math_instruction
backend_mi = MathInstruction(
mi.instruction_shape,
@@ -328,7 +328,7 @@ def td_from_profiler_op(op) -> TileDescription:
:param op: profiler Operation
:returns: backend TileDescription
:rtype: cutlass.backend.TileDescription
:rtype: cutlass_cppgen.backend.TileDescription
"""
kschedule = op.kernel_schedule if hasattr(op, 'kernel_schedule') else None
eschedule = op.epilogue_schedule if hasattr(op, 'epilogue_schedule') else None
@@ -341,10 +341,10 @@ def td_from_profiler_td(td: TileDescription) -> TileDescription:
Converts the profiler's TileDescription into the backend TileDescription
:param td: profiler TileDescription
:type td: cutlass.TileDescription
:type td: cutlass_cppgen.TileDescription
:returns: backend TileDescription
:rtype: cutlass.backend.TileDescription
:rtype: cutlass_cppgen.backend.TileDescription
"""
return construct_backend_td(td, kernel_schedule=None, epilogue_schedule=None, tile_scheduler=None)

View File

@@ -37,16 +37,16 @@ Profiler based on the cuda events
import re
import subprocess
from cutlass.utils.lazy_import import lazy_import
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")
cudart = lazy_import("cuda.cudart")
import numpy as np
from cutlass import CUTLASS_PATH
from cutlass.backend.library import DataTypeSize
from cutlass.op.op import OperationBase
from cutlass.shape import GemmCoord
from cutlass.utils.datatypes import is_numpy_tensor
from cutlass_cppgen import CUTLASS_PATH
from cutlass_cppgen.backend.library import DataTypeSize
from cutlass_cppgen.op.op import OperationBase
from cutlass_cppgen.shape import GemmCoord
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
class GpuTimer:

View File

@@ -49,7 +49,6 @@ from . import rank_2k_operation
from . import rank_k_operation
from . import symm_operation
from . import trmm_operation
# Make enum types from library.py accessible via cutlass_library.*
from .library import *

View File

@@ -279,7 +279,7 @@ def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups
def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
):
# For functional testing, we prefer to run reference computing on device if any
reference_device_archs = ["100a"]
reference_device_archs = ["100a", "103a"]
run_reference_on_device = True if arch in reference_device_archs and mode in ["functional_L0", "functional_L1"] else False
profiler_flags_for_verification = "device" if run_reference_on_device else "host"
@@ -287,7 +287,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
# TODO: randomize beta values for wider coverage
beta_values = [0.5]
is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "120a", "120f"])
is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "103a", "110a", "110f", "120a", "120f", "121a", "121f"])
is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch
@@ -306,6 +306,10 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
'bf16gemm_f32_f32_f32_f32_f32',
]
exclude_archs = arch not in ("103a")
if exclude_archs:
sm100_mma_data_type_general.append('gemm_s8_s8_s32_s8_s8')
sm100_mma_data_type_runtime_dtype = [
'gemm.*f4_f4_f32_f32_f32',
'gemm.*f6_f6_f32_f32_f32',
@@ -344,6 +348,11 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
]
sm103_block_scaled_data_type = [
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
]
block_scaled_cluster_size = [
'4x4x1', '2x1x1',
'0x0x1' # dynamic cluster
@@ -354,6 +363,9 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
if arch in ["100a", "100f"]:
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
@@ -361,15 +373,23 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
f"({sm100_mma_filter_regex_2sm_runtime})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})"
elif arch in ["101a", "101f",
]:
elif arch in ["101a", "101f", "110a", "110f"]:
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
f"({sm100_mma_filter_regex_1sm_runtime})|" \
f"({sm100_mma_filter_regex_2sm_runtime})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})"
elif arch in ["120a", "120f"]:
elif arch in ["103a"]:
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
f"({sm100_mma_filter_regex_1sm_runtime})|" \
f"({sm100_mma_filter_regex_2sm_runtime})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})|" \
f"({sm103_block_scaled_filter_regex_1sm})|" \
f"({sm103_block_scaled_filter_regex_2sm})"
elif arch in ["120a", "120f", "121a", "121f"]:
# blockscaled sm120_mma kernels
blockscaled_sm120_mma_kernel_cta_tiles = [
@@ -384,7 +404,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
kernel_filter = f"({filter_regex_blockscaled_sm120_mma})"
else:
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm120a, sm120f"
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm110a, sm110f, sm103a, sm120a, sm120f, sm121a, sm121f"
raise Exception(error_message)
elif mode == "functional_L1":
@@ -403,16 +423,27 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
'ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
]
block_scaled_cluster_size = ['4x4x1', '2x1x1', '0x0x1']
sm103_block_scaled_data_type = [
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
]
block_scaled_cluster_size = ['0x0x1']
block_scaled_layouts = ['tnt']
# regex list must be in kernel procedural name order
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \
f"({sm100_mma_filter_regex_2sm})|" \
f"({block_scaled_filter_regex_1sm})|" \
f"({block_scaled_filter_regex_2sm})"
f"({block_scaled_filter_regex_2sm})" \
f"({sm103_block_scaled_filter_regex_1sm})|" \
f"({sm103_block_scaled_filter_regex_2sm})"
# CTA tiles for sm120 MMA - only run one tile size to reduce build/test times
sm120_mma_kernel_cta_tiles = [
# h1688, s1688, i16832, i8816
@@ -449,7 +480,10 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
problem_waves = [0.5, 1.25, 2.5]
kernel_filter = f"({filter_regex_sm100_mma})|({filter_regex_sm120_mma})"
if arch in ["120a", "120f", "121a", "121f"]:
kernel_filter = f"({filter_regex_sm120_mma})"
else:
kernel_filter = f"({filter_regex_sm100_mma})"
else:
raise ValueError()

View File

@@ -341,7 +341,7 @@ class GemmOperation:
Get the tile shape passed to the collective builder.
On Blackwell, this is different than the operation.tile_description.tile_shape.
"""
is_sm100_kernel = (self.arch == 100)
is_sm100_kernel = (self.arch == 100 or self.arch == 103)
if not is_sm100_kernel:
return self.tile_description.tile_shape
@@ -995,6 +995,24 @@ ${compile_guard_end}
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
if not is_no_smem_epilogue:
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103:
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
if not is_no_smem_epilogue:
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm]
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103:
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
if not is_no_smem_epilogue:
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm]
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103:
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
if not is_no_smem_epilogue:
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm]
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103:
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
if not is_no_smem_epilogue:
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm]
element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>'
element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>'

View File

@@ -90,10 +90,12 @@ try:
raise ImportError("Disabling attempt to import cutlass_library")
from cutlass_library.library import *
from cutlass_library.manifest import *
from cutlass_library.heuristics import *
from cutlass_library.emit_kernel_listing import emit_gemm_kernel_testlist
except ImportError:
from library import *
from manifest import *
from heuristics import *
from emit_kernel_listing import emit_gemm_kernel_testlist
###################################################################################################
@@ -112,6 +114,10 @@ def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0):
cuda_version.append(x)
return cuda_version >= [major, minor, patch]
# From cuda 13.0, Thor SM is renumbered from 101 to 110
def ThorSMRenumbering(cuda_version):
return 110 if CudaToolkitVersionSatisfies(cuda_version, 13, 0) else 101
###################################################################################################
###################################################################################################
@@ -6768,9 +6774,11 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
},
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
math_instructions_1sm = [
# tf32 -> f32
MathInstruction(
@@ -6887,7 +6895,8 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
[[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 0]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
grouped = is_grouped(gemm_kind)
@@ -7202,9 +7211,11 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
[[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
epi_type = DataType.f32
grouped = is_grouped(gemm_kind)
@@ -7889,9 +7900,11 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
TileSchedulerType.Default, TileSchedulerType.StreamK
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
epi_type = DataType.f32
math_instructions_1sm = []
@@ -8092,6 +8105,8 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
return
grouped = is_grouped(gemm_kind)
layouts = [
[[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]],
[[LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 128], [LayoutType.RowMajor, 0]],
@@ -8120,14 +8135,16 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
def tile_schedulers(sfdtype):
# Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void,
# the epilogue is the traditional linear combination, for which we already have tests with stream-K.
if sfdtype["type"] == DataType.void:
if sfdtype["type"] == DataType.void or grouped:
return [TileSchedulerType.Default]
else:
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
epi_type = DataType.f32
math_instructions_1sm = []
@@ -8209,6 +8226,16 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.bf16,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
@@ -8246,7 +8273,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
for data_type in data_types:
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]]
[[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]]
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
cluster_shapes_2sm = [
@@ -8288,6 +8315,16 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.bf16,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
@@ -8346,7 +8383,11 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
if tile.threadblock_shape[2] // tile.cluster_shape[2] % layout[1][1] != 0:
continue
if math_inst.instruction_shape[0] == 128:
if grouped:
CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type],
[[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]]
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
elif math_inst.instruction_shape[0] == 128:
CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type],
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]]
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
@@ -8396,9 +8437,11 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
else:
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
epi_type = DataType.f32
math_instructions_1sm = []
@@ -8496,6 +8539,16 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.bf16,
"d_type" : DataType.bf16,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
@@ -8625,6 +8678,16 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.bf16,
"d_type" : DataType.bf16,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
@@ -8715,6 +8778,230 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
# SM100 MMA with F4 + block scale
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
return
# layouts for ABC and their alignments.
layouts = [
[[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]],
]
instruction_sizes_1sm = [
[128, 128, 96],
]
instruction_sizes_2sm = [
[256, 128, 96],
]
ab_types = [
DataType.e2m1,
]
acc_types = [ DataType.f32 ] # Accumulator is always 32 bits for block scaled MMA instructions
min_cc = 103
max_cc = 103
epi_type = DataType.f32
math_instructions_1sm = []
is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8)
for instr_size, a_type, b_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, acc_types):
is_runtime_datatype_a = is_runtime_datatype(a_type)
is_runtime_datatype_b = is_runtime_datatype(b_type)
# A/B datatypes should be both static or dynamic
if (is_runtime_datatype_a != is_runtime_datatype_b):
continue
math_instructions_1sm.append(
MathInstruction(
instr_size,
a_type, b_type, acc_type,
OpcodeClass.BlockScaledTensorOp,
MathOperation.multiply_add,
DataType.ue8m0) # UE8M0 scale factor
)
math_instructions_2sm = []
for instr_size, a_type, b_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, acc_types):
is_runtime_datatype_a = is_runtime_datatype(a_type)
is_runtime_datatype_b = is_runtime_datatype(b_type)
# A/B datatypes should be both static or dynamic
if (is_runtime_datatype_a != is_runtime_datatype_b):
continue
math_instructions_2sm.append(
MathInstruction(
instr_size,
a_type, b_type, acc_type,
OpcodeClass.BlockScaledTensorOp,
MathOperation.multiply_add,
DataType.ue8m0) # UE8M0 scale factor
)
cluster_shapes_1sm = [
[1,1,1],
# [1,2,1],
[2,1,1],
# [1,4,1],
[4,4,1]
, DynamicClusterShape
]
# 1xSM MMA kernels
for math_inst in math_instructions_1sm:
tile_descriptions = []
for cluster_shape in cluster_shapes_1sm:
multiplier_1sm = cluster_shape
tile_descriptions.append(
TileDescription([
math_inst.instruction_shape[0] * multiplier_1sm[0],
math_inst.instruction_shape[1] * multiplier_1sm[1],
768],
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
data_types = [
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.f32,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.bf16,
"d_type" : DataType.bf16,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.e2m1,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor}
},
]
for layout in layouts:
for data_type in data_types:
# Set alignment d based on Destination format.
if DataTypeSize[data_type["c_type"]] == 0 :
layout[2][1] = 256 // DataTypeSize[data_type["d_type"]]
else:
layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]])
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
# E2M1 x E2M1, vector size 32, E8
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
fp4_schedule = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, EpilogueScheduleType.NoSmemWarpSpecialized1Sm]
fp4_schedule_disable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, EpilogueScheduleType.NoSmemWarpSpecialized1Sm]
fp4_schedule_enable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, EpilogueScheduleType.NoSmemWarpSpecialized1Sm]
# For FP4 inputs
if isFp4:
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule, fp4_schedule_disable_prefetch
,fp4_schedule_enable_prefetch
]
, gemm_kind=gemm_kind
)
cluster_shapes_2sm = [
[2,1,1],
# [2,2,1],
# [2,4,1],
[4,1,1],
# [4,2,1],
[4,4,1]
, DynamicClusterShape
]
for math_inst in math_instructions_2sm:
tile_descriptions = []
for cluster_shape in cluster_shapes_2sm:
multiplier_2sm = (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
tile_descriptions.append(
TileDescription([
math_inst.instruction_shape[0] * multiplier_2sm[0],
math_inst.instruction_shape[1] * multiplier_2sm[1],
math_inst.instruction_shape[2] * 8 * multiplier_2sm[2]],
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
data_types = [
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.f32,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.bf16,
"d_type" : DataType.bf16,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
},
{
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
"c_type" : DataType.void,
"d_type" : DataType.e2m1,
"acc_type" : math_inst.element_accumulator,
"epi_type" : epi_type,
"sf_type" : math_inst.element_scale_factor,
"sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor}
},
]
for layout in layouts:
for data_type in data_types:
# Set alignment d based on Destination format.
if DataTypeSize[data_type["c_type"]] == 0 :
layout[2][1] = 256 // DataTypeSize[data_type["d_type"]]
else:
layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]])
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
# E2M1 x E2M1, vector size 32, E8
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
fp4_schedule = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, EpilogueScheduleType.NoSmemWarpSpecialized2Sm]
fp4_schedule_disable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, EpilogueScheduleType.NoSmemWarpSpecialized2Sm]
fp4_schedule_enable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, EpilogueScheduleType.NoSmemWarpSpecialized2Sm]
# For FP4 inputs
if isFp4:
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule, fp4_schedule_disable_prefetch
,fp4_schedule_enable_prefetch
]
, gemm_kind=gemm_kind
)
def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
@@ -8732,7 +9019,8 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
[[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
@@ -8948,9 +9236,11 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version):
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
tile_schedulers = [
TileSchedulerType.Default,
]
@@ -9074,9 +9364,11 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version):
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
tile_schedulers = [
TileSchedulerType.Default,
]
@@ -9200,7 +9492,8 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version):
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
@@ -9326,9 +9619,11 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version):
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
tile_schedulers = [
TileSchedulerType.Default,
]
@@ -9465,9 +9760,11 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
tile_schedulers = [
TileSchedulerType.Default,
]
@@ -9678,9 +9975,11 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
}
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
math_instructions_1sm = [
MathInstruction(
[128, 256, 8],
@@ -9772,9 +10071,11 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
[[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
math_instructions_1sm = [
MathInstruction(
[128, 256, 16],
@@ -9934,9 +10235,11 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
[[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]],
]
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
min_cc = 100
max_cc = thor_sm
epi_type = DataType.f32
math_instructions_1sm = [
@@ -10084,7 +10387,8 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version,
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
return
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
minimum_compute_capability = 100
maximum_compute_capability = thor_sm
@@ -10238,7 +10542,8 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
return
thor_sm = 101
thor_sm = ThorSMRenumbering(cuda_version)
minimum_compute_capability = 100
maximum_compute_capability = thor_sm
@@ -10422,7 +10727,7 @@ def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
min_cc = 120
max_cc = 120
max_cc = 121
epi_type = DataType.f32
@@ -10567,7 +10872,7 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
min_cc = 120
max_cc = 120
max_cc = 121
epi_type = DataType.f32
@@ -10720,7 +11025,7 @@ def GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version):
return [TileSchedulerType.Default]
min_cc = 120
max_cc = 120
max_cc = 121
kernel_schedules = [
KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120,
@@ -10840,7 +11145,7 @@ def GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version,
return [TileSchedulerType.Default]
min_cc = 120
max_cc = 120
max_cc = 121
kernel_schedulers = [
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120,
@@ -10924,7 +11229,11 @@ def GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version,
gemm_kind = gemm_kind)
def GenerateSM100(manifest, cuda_version):
arch_family_cc = ['100f', '101f']
arch_family_cc = ['100f', '101f', '103a']
if CudaToolkitVersionSatisfies(cuda_version, 13, 0):
for old_cc, new_cc in [('101f', '110f')]:
arch_family_cc = [cc.replace(old_cc, new_cc) for cc in arch_family_cc]
#
# Dense Gemm
#
@@ -10966,8 +11275,11 @@ def GenerateSM100(manifest, cuda_version):
# Block Scaled Gemm
#
GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version)
GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version)
GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version)
#
# Conv
#
@@ -11413,7 +11725,6 @@ def numeric_log_level(log_level: str) -> int:
raise ValueError(f'Invalid log level: {log_level}')
return numeric_level
# This function for defining the ArgumentParser is used to make it easy for the CUTLASS Python interface
# to leverage the functionality in this file without running this script via a shell prompt.
def define_parser():
@@ -11438,6 +11749,11 @@ def define_parser():
parser.add_argument("--filter-by-cc", default='True', type=str, help='If enabled, kernels whose compute capability range is not satisfied by the build target are excluded.')
parser.add_argument("--cuda-version", default="11.0.0", help="Semantic version string of CUDA Toolkit")
parser.add_argument('--kernel-filter-file', type=str, default=None, required=False, help='Full path of filter file')
parser.add_argument('--heuristics-problems-file', type=str, default=None, required=False, help='Full path of heuristics problem size description file, as a json list')
parser.add_argument('--heuristics-testlist-file', type=str, default=None, required=False, help='Full path of heuristics testlist CSV file, to be passed to cutlass_profiler')
parser.add_argument('--heuristics-gpu', type=str, default=None, required=False, help='GPU to use for evaluating heuristics offline. None or `auto` to autodetect using cuda', choices=['', 'auto', 'H100_SXM', 'H100_PCIE', 'H100_NVL', 'H200_SXM', 'H20_SXM', 'B200', 'GB200_NVL', 'RTX_5080', 'RTX_5090', 'RTX_PRO_6000'])
parser.add_argument('--heuristics-configs-per-problem', type=int, default=10, required=False, help='Number of kernel configs to generate for each problem in the problem list')
parser.add_argument('--heuristics-restrict-kernels', action='store_true', help='Restrict heuristics mode to use only the default set of kernels emitted by generator.py')
parser.add_argument('--selected-kernel-list', type=str, default=None, required=False,
help='Specify the output log file containing all enabled kernels in this build')
parser.add_argument("--interface-dir", default=None, required=False, help="Interface header to kernels")
@@ -11460,6 +11776,9 @@ if __name__ == "__main__":
archs = args.architectures.split(';')
if args.heuristics_problems_file:
filter_manifest_and_write_heuristics_file(manifest, args)
GenerateSM50(manifest, args.cuda_version)
GenerateSM60(manifest, args.cuda_version)
GenerateSM61(manifest, args.cuda_version)
@@ -11468,17 +11787,20 @@ if __name__ == "__main__":
GenerateSM80(manifest, args.cuda_version)
GenerateSM89(manifest, args.cuda_version)
GenerateSM90(manifest, args.cuda_version)
blackwell_arch_list = [
"100a", "100f",
"101a", "101f",
"120a", "120f"
"103a", "103f",
"110a", "110f",
"120a", "120f",
"121a", "121f",
]
blackwell_enabled_arch = any(arch in blackwell_arch_list for arch in archs)
if blackwell_enabled_arch:
GenerateSM100(manifest, args.cuda_version)
GenerateSM120(manifest, args.cuda_version)
if 'library' in args.generator_target.split(','):
manifest.emit(GeneratorTarget.Library)

View File

@@ -0,0 +1,414 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utilities for selecting CUTLASS library kernels based on problem description
"""
import json
import csv
try:
if CUTLASS_IGNORE_PACKAGE:
raise ImportError("Disabling attempt to import cutlass_library")
from cutlass_library.library import *
from cutlass_library.generator import *
from cutlass_library.heuristics_provider import *
except ImportError:
from library import *
from generator import *
from heuristics_provider import *
try:
from .sm90_utils import (
get_valid_schedules,
generate_data_types_from_math_instruction,
fix_alignments,
)
except ImportError:
from sm90_utils import (
get_valid_schedules,
generate_data_types_from_math_instruction,
fix_alignments,
)
# Module-level logger for heuristics selection diagnostics.
_LOGGER = logging.getLogger(__name__)

# Reverse mapping from short dtype names (e.g. 'f16', 'e4m3') back to DataType
# enum values, built from cutlass_library's DataTypeNames table.
dtype_map = {v: k for k, v in DataTypeNames.items()}
def serialize_heuristics_results_to_json(problems_with_configs, outfile_path):
    """
    Utility function to write heuristics results to a json file for debug.

    args:
        problems_with_configs: List of problems provided to the heuristic, with a list of
            operations added to each problem dict (under the 'configs' key)
        outfile_path: Outfile path

    returns:
        None
    """
    def _json_safe(d):
        # Convert DataType / LayoutType enum values to their short string names so
        # json.dump can serialize them. Builds a NEW dict: the previous implementation
        # used a shallow copy and rewrote values in place, destructively replacing the
        # enum values in the caller's dictionaries.
        out = {}
        for k, v in d.items():
            if isinstance(v, DataType):
                out[k] = DataTypeNames[v]
            elif isinstance(v, LayoutType):
                out[k] = ShortLayoutTypeNames[v]
            else:
                out[k] = v
        return out

    serializable = []
    for p in problems_with_configs:
        p_copy = _json_safe(p)
        # 'configs' is a nested list of dicts that also needs enum conversion.
        p_copy['configs'] = [_json_safe(c) for c in p['configs']]
        serializable.append(p_copy)

    with open(outfile_path, 'w') as f:
        json.dump(serializable, f, indent=2)
def get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, voidC=False, use_fast_acc=True, count=1, provider=None):
    """
    Get heuristic-suggested GEMM kernel configurations for a single GEMM problem.

    args:
        m, n, k: GEMM dimensions
        batch_count: batch count
        layouts: tuple of layouts of type LayoutType, in (A, B, D) order
        dtypes: tuple of DataType values, in (a, b, accumulator, c, d) order
        alignment_a, alignment_b: input tensor alignments, in count of elements
        voidC: True when the C operand is unused (i.e. beta == 0)
        use_fast_acc: Use fast accumulation for FP8. Ignored for other precisions
        count: Number of configs to return
        provider: Heuristics provider to use; a fresh MatmulHeuristics instance is
            created when None

    returns:
        A list of dictionaries containing the suggested kernel configurations and additional info from the input required to define a Cutlass GemmOperation, with the following keys:
        - 'cta_tile_m', 'cta_tile_n', 'cta_tile_k': CTA tile size
        - 'instr_tile_m', 'instr_tile_n', 'instr_tile_k': Instruction tile size
        - 'stages': kernel pipeline stage count
        - 'cluster_m', 'cluster_n', 'cluster_k': cluster size
        - 'layout_a', 'layout_b': input tensor layouts of type LayoutType
        - 'alignment_a', 'alignment_b': input tensor alignments, in count of elements
        - 'dtype_a', 'dtype_b', 'dtype_acc': dtypes of a, b, and accumulator, of type DataType
        - 'swizzle_size' : suggested threadblock swizzle
        - 'split_k_slices': number of partitions of the k dimension for splitK
        - 'raster_order': raster order for CTAs over output tiles ('along_m' or 'along_n')
    """
    if provider is None:
        provider = MatmulHeuristics()
    return provider.get_configs(m, n, k, batch_count, dtypes, layouts, alignment_a, alignment_b, voidC=voidC, use_fast_acc=use_fast_acc, count=count)
def get_gemm_configs(problems, provider=None, count=1):
    """
    Get heuristic-suggested GEMM kernel configurations for a set of GEMM problems.

    args:
        problems: List of dictionaries describing GEMM problems with the following keys:
            - 'm', 'n', 'k': Matrix dimensions (required)
            - 'dtype_a': Data type of matrix A (required)
            - 'dtype_b': Data type of matrix B (required)
            - 'dtype_c': Data type of matrix C (default: None)
            - 'dtype_d': Data type of matrix D (required)
            - 'dtype_acc': Compute data type (default 'f32')
            - 'layout': Operation layout (e.g. 'tnt')
            - 'alignment_a': Memory access granularity of A, in units of elements (default: 16 bytes equivalent elements)
            - 'alignment_b': Memory access granularity of B, in units of elements (default: 16 bytes equivalent elements)
            - 'alpha': Scalar multiplier for A*B (default: 1.0)
            - 'beta': Scalar multiplier for C (default: 0.0)
            - 'batch_count': Number of GEMM operations in batch (default: 1)
            - 'use_fast_acc': Enable fast accumulation for FP8 on Hopper (default: True)
        provider: Heuristics provider to use
        count: Number of configurations to return per problem (default: 1)

    returns:
        A copy of the input list of dictionaries, with key `configs` added to each entry
        containing the selected gemm configs

    raises:
        KeyError: when a required problem key or a dtype name is missing/unknown
        ValueError: for unsupported operations or malformed layout strings
    """
    ret = []
    for problem in problems:
        # Shallow copy so adding 'configs' below does not mutate the caller's dict.
        problem = problem.copy()
        try:
            m = problem['m']
            n = problem['n']
            k = problem['k']
            dtype_a = problem['dtype_a']
            dtype_b = problem['dtype_b']
            dtype_d = problem['dtype_d']
            layout = problem['layout']
        except KeyError as e:
            _LOGGER.error(f"Missing required parameter {e} for problem {problem}")
            raise
        operation = problem.get('operation', 'gemm')
        batch_count = problem.get('batch_count', 1)
        dtype_acc = problem.get('dtype_acc', 'f32')
        dtype_c = problem.get('dtype_c', None)
        alpha = problem.get('alpha', 1.0)
        beta = problem.get('beta', 0.0)
        use_fast_acc = problem.get('use_fast_acc', True)
        if operation != OperationKindNames[OperationKind.Gemm]:
            raise ValueError(f"Unsupported operation {operation}")
        if not (len(layout) == 3 and all(c in "nt" for c in layout)):
            raise ValueError(f"layout must be a 3-character string containing only 'n' or 't', got {layout}")
        # 't' (transposed) maps to row-major, 'n' (non-transposed) to column-major.
        layouts = tuple(LayoutType.RowMajor if l == 't' else LayoutType.ColumnMajor for l in layout)
        try:
            # When dtype_c is absent, assume C matches D's type. Order: a, b, acc, c, d.
            dtype_list = [dtype_a.lower(), dtype_b.lower(), dtype_acc.lower(), dtype_c.lower() if dtype_c is not None else dtype_d.lower(), dtype_d.lower()]
            dtypes = tuple(dtype_map[dt] for dt in dtype_list)
        except KeyError as dt:
            _LOGGER.error(f"Unsupported data type: {dt}")
            raise
        # Default alignment is 128 bits (16 bytes) worth of elements.
        alignment_a = problem.get('alignment_a', 128 // DataTypeSize[dtypes[0]])
        alignment_b = problem.get('alignment_b', 128 // DataTypeSize[dtypes[1]])
        # beta == 0.0 means C is never read, so the kernel may treat it as void.
        configs = get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, beta==0.0, use_fast_acc, count, provider)
        problem['configs'] = configs
        ret.append(problem)
    return ret
def generate_sm100_from_heuristics_configs(manifest, cuda_version, kernel_configs):
    """
    Generate CUTLASS operations based on the list of configs provided by the heuristic provider.

    args:
        manifest: manifest argument to which to add operations, or None to just return the
            operations without a manifest (for pruning an existing manifest)
        cuda_version: Cuda compiler version for generating cutlass operations
        kernel_configs: list of configs generated by the heuristic

    returns:
        (configs, operations): a list of heuristic-provided kernel configs along with a
        one-to-one corresponding list of the generated operations
    """
    # Target SM100-family compute capabilities (Blackwell).
    min_cc = 100
    max_cc = 101
    if manifest is None:
        # Use a dummy manifest so we can use existing CreateGemmOperator functions
        manifest = Manifest()
    configs = []
    operations = []
    for config in kernel_configs:
        # D alignment defaults to 128 bits worth of elements; A/B come from the heuristic.
        layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [config['layout_d'], 128 // DataTypeSize[config['dtype_d']]])
        element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d']
        # nvMMH assumes 2sm instruction for !(cluster_m % 2)
        is_2sm = config['cluster_m'] % 2 == 0
        # 2SM instructions span two CTAs in M; the K instruction extent is a quarter
        # of the CTA tile K (the *4 factor is re-applied when building the tile below).
        instruction_shape = [(2 * config['cta_tile_m']) if is_2sm else config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k'] // 4]
        math_instruction = MathInstruction(
            instruction_shape,
            element_a, element_b, element_accumulator,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add
        )
        data_types = [
            {
                # voidC skips the C operand entirely (beta == 0 path).
                "a_type"   : math_instruction.element_a,
                "b_type"   : math_instruction.element_b,
                "c_type"   : DataType.void if config['voidC'] else math_instruction.element_accumulator,
                "d_type"   : element_d,
                "acc_type" : math_instruction.element_accumulator,
                "epi_type" : math_instruction.element_accumulator,
            }
        ]
        # Per-CTA cluster extent: a 2SM instruction already covers two CTAs in M.
        tile_multiplier = (config['cluster_m'] // (2 if is_2sm else 1), config['cluster_n'], config['cluster_k'])
        tile_description = TileDescription(
            [instruction_shape[0] * tile_multiplier[0],
             instruction_shape[1] * tile_multiplier[1],
             instruction_shape[2] * 4 * tile_multiplier[2]],
            0,  # stage count: 0 lets the builder pick automatically — TODO confirm
            [4,1,1],
            math_instruction,
            min_cc,
            max_cc,
            cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k'])
        )
        # Pick the kernel/epilogue schedule pair matching the 1SM/2SM instruction choice.
        schedules = []
        if is_2sm:
            schedules.append([KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm])
        else:
            schedules.append([KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm])
        # Emit one (config, operation) pair per generated operation so the two
        # returned lists stay in one-to-one correspondence.
        for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, tile_schedulers=[TileSchedulerType.Default, TileSchedulerType.StreamK], gemm_kind=GemmKind.Universal3x):
            configs.append(config)
            operations.append(o)
    return configs, operations
def generate_sm90_from_heuristics_configs(manifest, cuda_version, kernel_configs):
    """
    Generate CUTLASS operations based on the list of configs provided by the heuristic provider.

    args:
        manifest: manifest argument to which to add operations, or None to just return the
            operations without a manifest (for pruning an existing manifest)
        cuda_version: Cuda compiler version for generating cutlass operations
        kernel_configs: list of configs generated by the heuristic

    returns:
        (configs, operations): a list of heuristic-provided kernel configs along with a
        one-to-one corresponding list of the generated operations
    """
    # SM90 (Hopper) only.
    min_cc, max_cc = 90, 90
    if manifest is None:
        # Use a dummy manifest so we can use existing CreateGemmOperator functions
        manifest = Manifest()
    configs = []
    operations = []
    for config in kernel_configs:
        # "Aligned" means both inputs satisfy 128-bit access granularity, which gates
        # the TMA-based schedules selected by get_valid_schedules.
        is_aligned = (config['alignment_a'] * DataTypeSize[config['dtype_a']] >= 128) and (config['alignment_b'] * DataTypeSize[config['dtype_b']] >= 128)
        layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [LayoutType.ColumnMajor, 1])
        element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d']
        # instr shape and warp config are unused for emitting 3x collective builder code
        dummy_instr_shape = [0, 0, 0]
        math_instruction = MathInstruction(
            dummy_instr_shape,
            element_a, element_b, element_accumulator,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add
        )
        data_types = generate_data_types_from_math_instruction(math_instruction, element_source=element_c, element_dest=element_d)
        if is_aligned:
            layout = fix_alignments(data_types, layout, alignment_bits=128)
        # instr shape and warp config are unused for emitting 3x collective builder code
        dummy_warp_count = [0, 0, 0]
        tile_description = TileDescription(
            [config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k']],
            0,  # stage count: 0 lets the builder pick automatically — TODO confirm
            dummy_warp_count,
            math_instruction,
            min_cc,
            max_cc,
            cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k'])
        )
        schedules, stream_k_schedules = get_valid_schedules(
            tile_description=tile_description,
            cuda_version=cuda_version,
            is_aligned=is_aligned,
            data_types=data_types,
            instantiation_level=9000,  # don't prune schedules: we didn't get any schedule suggestion from the heuristic
            layout=layout,
            gemm_kind=GemmKind.Universal3x,
            enable_fp8_fast_acc=config['use_fast_acc']
        )
        # Emit one (config, operation) pair per generated operation so the two
        # returned lists stay in one-to-one correspondence.
        if len(schedules):
            for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, gemm_kind=GemmKind.Universal3x):
                configs.append(config)
                operations.append(o)
        if len(stream_k_schedules):
            for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types,
                                                   stream_k_schedules,
                                                   tile_schedulers=[TileSchedulerType.StreamK]):
                configs.append(config)
                operations.append(o)
    return configs, operations
def filter_manifest_and_write_heuristics_file(manifest, args):
    """
    Prune a manifest according to heuristics suggestions from the problems file.

    args:
        manifest: Cutlass manifest to prune
        args: generator.py args, requires:
            - args.heuristics_problems_file
            - args.heuristics_gpu
            - args.heuristics_testlist_file
            - args.architectures
            - args.cuda_version
            - args.heuristics_configs_per_problem
            - args.heuristics_restrict_kernels

    returns:
        A list of dictionaries, each of which has information about an operation and a
        problem from the input problems

    raises:
        Exception: when no valid configurations were generated for any problem
    """
    heuristics_problems = []
    with open(args.heuristics_problems_file, 'r') as f:
        heuristics_problems = json.load(f)
    # Empty string or "auto" means "let nvMatmulHeuristics autodetect the GPU".
    gpu = None if (args.heuristics_gpu == "auto" or args.heuristics_gpu == "") else args.heuristics_gpu
    mmh = MatmulHeuristics(gpu=gpu)
    archs = args.architectures.split(';')
    if any(('100' in arch) for arch in archs):
        mmh.set_cta_div_n(64)
    problems_with_configs = get_gemm_configs(heuristics_problems, provider=mmh, count=args.heuristics_configs_per_problem)
    all_configs_and_operations = []
    operations = []
    for problem in problems_with_configs:
        # Accumulate results from every architecture family requested. The previous
        # implementation rebound (problem_configs, problem_operations) per branch, so
        # an arch list matching neither branch raised NameError, and when both SM90 and
        # SM100 archs were requested the SM100 results silently replaced the SM90 ones
        # (dropping the SM90 kernel filters added below).
        problem_configs = []
        problem_operations = []
        if any('90' in arch for arch in archs):
            cfgs, ops = generate_sm90_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs'])
            problem_configs += cfgs
            problem_operations += ops
        if any(('100' in arch) or ('101' in arch) for arch in archs):
            cfgs, ops = generate_sm100_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs'])
            problem_configs += cfgs
            problem_operations += ops
        operations += problem_operations
        # Flatten problem description + config + operation name into one row per kernel.
        problem_without_configs = {k: v for k, v in problem.items() if k != 'configs'}
        with_problem_size = [{'operation_name': o.procedural_name(), **problem_without_configs, **c} for c, o in zip(problem_configs, problem_operations)]
        all_configs_and_operations += with_problem_size
    # Restrict the manifest to exactly the heuristic-selected kernels.
    for operation in operations:
        manifest.add_kernel_filter(f"^{operation.procedural_name()}$")
    if not all_configs_and_operations:
        raise Exception("No valid configurations generated")
    write_profiler_testlist_to_csv(all_configs_and_operations, args.heuristics_testlist_file)
    return all_configs_and_operations
def write_profiler_testlist_to_csv(configs_list, outfile_path):
    """
    Write a list of configs to a testlist to be consumed by cutlass_profiler.

    args:
        configs_list: List of kernel configs along with runtime arguments and any other
            columns to include in the CSV, expressed as a list of dictionaries
        outfile_path: Outfile path

    returns:
        None
    """
    def _csv_safe(row):
        # Convert DataType / LayoutType enum values to their short string names for the
        # CSV. Builds a NEW dict: the previous implementation used a shallow copy and
        # rewrote values in place, destructively replacing the enum values in the
        # caller's dictionaries.
        out = {}
        for k, v in row.items():
            if isinstance(v, DataType):
                out[k] = DataTypeNames[v]
            elif isinstance(v, LayoutType):
                out[k] = ShortLayoutTypeNames[v]
            else:
                out[k] = v
        return out

    profiler_testlist = [_csv_safe(c) for c in configs_list]
    with open(outfile_path, mode='w', newline='') as ofile:
        # Column order follows the first row's keys; all rows are assumed to share
        # the same key set — TODO confirm against producers of configs_list.
        k_names = profiler_testlist[0].keys()
        writer = csv.DictWriter(ofile, fieldnames=k_names)
        writer.writeheader()
        writer.writerows(profiler_testlist)

View File

@@ -0,0 +1,168 @@
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Providers for kernel selection heuristics
"""
import sys
import os
import glob
import logging
import ctypes
import functools
from library import DataType, LayoutType
class MatmulHeuristics:
    """
    Thin wrapper around the nvMatmulHeuristics Python bindings, targeting the
    CUTLASS3 backend, used to suggest GEMM kernel configurations.
    """

    def __init__(self, gpu = None):
        """
        Create a heuristics interface.

        args:
            gpu: Name of an NvMatmulHeuristicsNvidiaGpu entry (e.g. 'H100_SXM'),
                or None to let the library autodetect the GPU.
        """
        # Imported lazily so the rest of the module works without nvMatmulHeuristics installed.
        import nvMatmulHeuristics
        self.mmh_lib = nvMatmulHeuristics
        self.gpu = gpu
        # CUTLASS_NVMMH_SO_PATH optionally points at an explicit shared-library build.
        if 'CUTLASS_NVMMH_SO_PATH' in os.environ:
            nvmmhInterfaceEx = functools.partial(self.mmh_lib.NvMatmulHeuristicsInterfaceEx, path=os.environ['CUTLASS_NVMMH_SO_PATH'])
        else:
            nvmmhInterfaceEx = self.mmh_lib.NvMatmulHeuristicsInterfaceEx
        self.lh = nvmmhInterfaceEx(
            backend=self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"],
            flags=self.mmh_lib.NvMatmulHeuristicsFlags.PERF_MODEL_BASED_AUTO_TUNING,
            load_discovery_implicitly=True,
            gpu=self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None
        )
        self.backend = self.lh.createBackend(self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"])

    def _layout_from_cutlass(self, layouts):
        """Map a 3-tuple of LayoutType (A, B, D) to an NvMatmulHeuristicsMatmulLayout enum."""
        assert(len(layouts)==3)
        # Row-major -> 't' (transposed), column-major -> 'n', per BLAS convention.
        full_layout_str = ''.join('t' if l == LayoutType.RowMajor else 'n' for l in layouts)
        input_layouts = full_layout_str[:2].upper()
        # e.g. 'tnt' -> 'TN_ROW_MAJOR'
        lh_layout = input_layouts + '_' + str("ROW_MAJOR" if full_layout_str[-1]=='t' else "COL_MAJOR")
        return self.mmh_lib.NvMatmulHeuristicsMatmulLayout[lh_layout]

    def _precision_from_cutlass_dtypes(self, dtypes):
        """
        Build the cuBLAS-style precision string nvMatmulHeuristics expects from a
        (a, b, compute, c, d) tuple of DataType values.

        Non-FP8 types use a 3-letter string (A/compute/D); FP8 ('Q' for e4m3) uses the
        extended 5-letter form (A/B/C/compute/D).
        """
        dtype_to_cublas = {
            DataType.f64: 'D',
            DataType.f32: 'S',
            DataType.f16: 'H',
            DataType.bf16: 'T',
            DataType.e4m3: 'Q',
            DataType.e5m2: 'R',
            DataType.s32: 'I',
            DataType.s8: 'B',
        }
        dtype_a, dtype_b, dtype_compute, dtype_c, dtype_d = dtypes
        a_c = dtype_to_cublas[dtype_a]
        if a_c.lower() != 'q':
            return a_c + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d]
        else:
            return a_c + dtype_to_cublas[dtype_b] + dtype_to_cublas[dtype_c] + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d]

    def set_cta_div_n(self, div_n):
        """Require suggested CTA tile N to be divisible by div_n."""
        cta_n_div_requirement = ctypes.c_int(div_n)
        self.lh.setBackendValueProperty(
            self.backend,
            self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT,
            ctypes.byref(cta_n_div_requirement),
            ctypes.sizeof(cta_n_div_requirement)
        )

    def set_cta_div_m(self, div_m):
        """Require suggested CTA tile M to be divisible by div_m."""
        cta_m_div_requirement = ctypes.c_int(div_m)
        self.lh.setBackendValueProperty(
            self.backend,
            self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT,
            ctypes.byref(cta_m_div_requirement),
            ctypes.sizeof(cta_m_div_requirement)
        )

    def get_configs(self, m, n, k, batch_count, dtypes, layouts, align_a, align_b, voidC=False, use_fast_acc=True, count=1):
        """
        Query nvMatmulHeuristics for up to `count` kernel configurations.

        args:
            m, n, k: GEMM dimensions
            batch_count: batch count
            dtypes: tuple of DataType values in (a, b, compute, c, d) order
            layouts: tuple of LayoutType values in (A, B, D) order
            align_a, align_b: input alignments in element count (passed through to the result)
            voidC: True when the C operand is unused (passed through to the result)
            use_fast_acc: Enable fast accumulation for FP8
            count: Number of configurations to request

        returns:
            List of dicts describing each suggested kernel config (tile/cluster sizes,
            layouts, dtypes, swizzle, raster order, split-k, estimated runtime, ...).
        """
        # DISABLE_FAST_ACC_FOR_FP8 is phrased negatively, hence the inversion.
        if use_fast_acc:
            disable_fast_acc_for_fp8 = ctypes.c_int(0)
        else:
            disable_fast_acc_for_fp8 = ctypes.c_int(1)
        self.lh.setBackendValueProperty(
            self.backend,
            self.mmh_lib.NvMatmulHeuristicsBackendProperty.DISABLE_FAST_ACC_FOR_FP8,
            ctypes.byref(disable_fast_acc_for_fp8),
            ctypes.sizeof(disable_fast_acc_for_fp8)
        )
        precision = self._precision_from_cutlass_dtypes(dtypes)
        layout = self._layout_from_cutlass(layouts)
        matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, n, k, layout, batch_count)
        configs = self.lh.getEx(matmul_problem, count, self.backend, precision=precision)
        # Reshape the library's (kernel, problem, runtime) results into flat dicts
        # understood by the generator code in heuristics utilities.
        ret = []
        for c in configs:
            kernel = c['kernel']
            problem = c['problem']
            r = {}
            r['estimated_runtime'] = c['runtime']
            r['cta_tile_m'] = kernel.cta_tile_m
            r['cta_tile_n'] = kernel.cta_tile_n
            r['cta_tile_k'] = kernel.cta_tile_k
            r['instr_tile_m'] = kernel.instr_tile_m
            r['instr_tile_n'] = kernel.instr_tile_n
            r['instr_tile_k'] = kernel.instr_tile_k
            r['warp_tile_m'] = kernel.warp_tile_m
            r['warp_tile_n'] = kernel.warp_tile_n
            r['warp_tile_k'] = kernel.warp_tile_k
            r['cluster_m'] = kernel.cluster_m
            r['cluster_n'] = kernel.cluster_n
            # Cluster K is always 1; the heuristic does not suggest K-dimension clusters.
            r['cluster_k'] = 1
            r['layout_a'] = layouts[0]
            r['layout_b'] = layouts[1]
            r['layout_d'] = layouts[2]
            r['dtype_a'] = dtypes[0]
            r['dtype_b'] = dtypes[1]
            r['dtype_acc'] = dtypes[2]
            r['dtype_c'] = dtypes[3]
            r['dtype_d'] = dtypes[4]
            r['alignment_a'] = align_a
            r['alignment_b'] = align_b
            r['swizzle_size'] = kernel.swizzle_factor
            # cta_order 0 is assumed to mean M-major rasterization — TODO confirm
            # against nvMatmulHeuristics docs.
            r['raster_order'] = 'along_m' if kernel.cta_order==0 else 'along_n'
            r['split_k_slices'] = kernel.split_k
            r['use_fast_acc'] = use_fast_acc
            r['voidC'] = voidC
            ret.append(r)
        return ret

View File

@@ -546,6 +546,22 @@ class KernelScheduleType(enum.Enum):
Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()
# FP4 Ultra
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto()
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto()
Mxf8f6f4TmaWarpSpecializedCooperativeSm120 = enum_auto()
Mxf8f6f4TmaWarpSpecializedPingpongSm120 = enum_auto()
Nvf4TmaWarpSpecializedCooperativeSm120 = enum_auto()
@@ -603,6 +619,22 @@ KernelScheduleTag = {
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100',
# FP4 Ultra
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch',
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
@@ -677,6 +709,21 @@ KernelScheduleSuffixes = {
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_1sm',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_2sm',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_1sm',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_2sm',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_1sm_nopf',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_2sm_nopf',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_1sm_nopf',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_2sm_nopf',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_1sm_tmapf',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_2sm_tmapf',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_1sm_tmapf',
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_2sm_tmapf',
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
@@ -713,8 +760,12 @@ class EpilogueScheduleType(enum.Enum):
PtrArrayNoSmemWarpSpecialized = enum_auto()
NoSmemWarpSpecialized1Sm = enum_auto()
NoSmemWarpSpecialized2Sm = enum_auto()
FastF32NoSmemWarpSpecialized1Sm = enum_auto()
FastF32NoSmemWarpSpecialized2Sm = enum_auto()
PtrArrayNoSmemWarpSpecialized1Sm = enum_auto()
PtrArrayNoSmemWarpSpecialized2Sm = enum_auto()
PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto()
PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto()
TmaWarpSpecialized = enum_auto()
TmaWarpSpecializedCooperative = enum_auto()
TmaWarpSpecialized1Sm = enum_auto()
@@ -732,8 +783,12 @@ EpilogueScheduleTag = {
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized',
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::NoSmemWarpSpecialized1Sm',
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized1Sm',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized1Sm',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized2Sm',
EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm',
@@ -752,8 +807,12 @@ EpilogueScheduleSuffixes = {
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: '_epi_nosmem',
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: '_epi_nosmem',
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem',
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecialized1Sm: '',

View File

@@ -526,44 +526,49 @@ class Manifest:
if args.filter_by_cc in ['false', 'False', '0']:
self.filter_by_cc = False
if args.operations == 'all':
self.operations_enabled = []
else:
operations_list = [
OperationKind.Gemm
, OperationKind.Conv2d
, OperationKind.Conv3d
, OperationKind.RankK
, OperationKind.Trmm
, OperationKind.Symm
]
self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')]
if args.operations == 'all':
self.operations_enabled = []
else:
operations_list = [
OperationKind.Gemm
, OperationKind.Conv2d
, OperationKind.Conv3d
, OperationKind.RankK
, OperationKind.Trmm
, OperationKind.Symm
]
self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')]
if args.kernels == 'all':
self.kernel_names = []
else:
self.kernel_names = [x for x in args.kernels.split(',') if x != '']
if args.kernels == 'all':
self.kernel_names = []
else:
self.kernel_names = [x for x in args.kernels.split(',') if x != '']
self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != '']
self.exclude_kernel_names = [x for x in args.exclude_kernels.split(',') if x != '']
self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != '']
self.exclude_kernel_names = [x for x in args.exclude_kernels.split(',') if x != '']
if args.kernel_filter_file is None:
self.kernel_filter_list = []
else:
self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file)
_LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format(
filter_count = len(self.kernel_filter_list),
filter_file = args.kernel_filter_file))
if args.kernel_filter_file is None:
self.kernel_filter_list = []
else:
self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file)
_LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format(
filter_count = len(self.kernel_filter_list),
filter_file = args.kernel_filter_file))
self.operation_count = 0
self.operations_by_name = {}
self.disable_full_archs_compilation = args.disable_full_archs_compilation
self.is_kernel_filter_set_to_all = args.instantiation_level == "max" and args.kernels != ''
self.instantiation_level = 0
try:
self.instantiation_level = int(args.instantiation_level)
except ValueError:
self.instantiation_level = 0
self.operation_count = 0
self.operations_by_name = {}
self.disable_full_archs_compilation = args.disable_full_archs_compilation
self.is_kernel_filter_set_to_all = args.instantiation_level == "max" and args.kernels != ''
self.instantiation_level = 0
try:
self.instantiation_level = int(args.instantiation_level)
except ValueError:
self.instantiation_level = 0
def add_kernel_filter(self, filter_str):
    """Compile *filter_str* as a regular expression and register it as a kernel filter.

    The compiled pattern is appended to ``self.kernel_filter_list`` and is
    later matched against generated kernel names.
    """
    self.kernel_filter_list.append(re.compile(filter_str))
def get_sm90_instantiation_level(self, pruned_level=0, default_level=111, exhaustive_level=9992):
# Non-negative integer which determines how many kernels are instantiated.

View File

@@ -407,7 +407,7 @@ def generate_tile_descriptions_sm90(math_instructions, is_aligned: bool, level:
def is_tile_desc_compatible_with_cooperative(tile_description):
    """Return True if *tile_description* can use a cooperative kernel schedule.

    Cooperative kernels require the CTA-M extent (``threadblock_shape[0]``)
    to be a multiple of 128.
    """
    # Defect fixed: the body contained two consecutive `return` statements
    # (a stale `>= 128` check followed by an unreachable `% 128 == 0` check).
    # The multiple-of-128 condition is the current, stricter constraint, so
    # it is kept and the dead duplicate return is removed.
    # NOTE(review): confirm against the cooperative mainloop requirements
    # that multiple-of-128 (not merely >= 128) is the intended rule.
    return tile_description.threadblock_shape[0] % 128 == 0
def can_tile_desc_use_shmem_in_epilogue(tile_description, data_types):

View File

@@ -50,17 +50,17 @@ setup_pycute.perform_setup()
setup(
name='cutlass',
version='3.4.0',
name='cutlass_cppgen',
version='4.0.0',
description='CUTLASS Pythonic Interface',
package_dir={'': '.'},
packages=[
'cutlass',
'cutlass.emit',
'cutlass.op',
'cutlass.utils',
'cutlass.backend',
'cutlass.backend.utils'
'cutlass_cppgen',
'cutlass_cppgen.emit',
'cutlass_cppgen.op',
'cutlass_cppgen.utils',
'cutlass_cppgen.backend',
'cutlass_cppgen.backend.utils'
],
setup_requires=['pybind11'],
install_requires=[