mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 14:28:59 +00:00
v4.2 release. (#2587)
* Fix default cluster callback values to 1 to avoid profiler failure when these values are not set in command line. * v4.2 release.
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
|
||||
This directory contains Python packages that are associated with CUTLASS:
|
||||
|
||||
* `cutlass`: the CUTLASS Python interface, which enables one to compile and run CUTLASS kernels from within Python
|
||||
* `cutlass_cppgen`: the CUTLASS Python interface, which enables one to compile and run CUTLASS kernels from within Python. Note that this was previously named `cutlass`, but was renamed to disambiguate with the CuTe Python DSL.
|
||||
* `cutlass_library`: utilities used for enumerating and emitting C++ code for CUTLASS kernels
|
||||
|
||||
## CUTLASS Python Interface
|
||||
|
||||
@@ -119,8 +119,8 @@ def set_log_level(level: int):
|
||||
|
||||
set_log_level(logging.ERROR)
|
||||
|
||||
from cutlass.library_defaults import OptionRegistry
|
||||
from cutlass.backend.utils.device import device_cc
|
||||
from cutlass_cppgen.library_defaults import OptionRegistry
|
||||
from cutlass_cppgen.backend.utils.device import device_cc
|
||||
|
||||
this._option_registry = None
|
||||
def get_option_registry():
|
||||
@@ -135,14 +135,14 @@ def get_option_registry():
|
||||
|
||||
this.__version__ = '4.1.0'
|
||||
|
||||
from cutlass.backend import create_memory_pool
|
||||
from cutlass.emit.pytorch import pytorch
|
||||
from cutlass.op.gemm import Gemm
|
||||
from cutlass.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
|
||||
from cutlass.op.gemm_grouped import GroupedGemm
|
||||
from cutlass.op.op import OperationBase
|
||||
from cutlass.backend.evt.ir.tensor import Tensor
|
||||
from cutlass.utils.lazy_import import lazy_import
|
||||
from cutlass_cppgen.backend import create_memory_pool
|
||||
from cutlass_cppgen.emit.pytorch import pytorch
|
||||
from cutlass_cppgen.op.gemm import Gemm
|
||||
from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
|
||||
from cutlass_cppgen.op.gemm_grouped import GroupedGemm
|
||||
from cutlass_cppgen.op.op import OperationBase
|
||||
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
|
||||
|
||||
this.memory_pool = None
|
||||
|
||||
@@ -30,19 +30,19 @@
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass.backend.arguments import *
|
||||
from cutlass.backend.c_types import *
|
||||
from cutlass.backend.compiler import ArtifactManager
|
||||
from cutlass.backend.conv2d_operation import *
|
||||
from cutlass.backend.epilogue import *
|
||||
from cutlass.backend.frontend import *
|
||||
from cutlass.backend.gemm_operation import *
|
||||
from cutlass.backend.library import *
|
||||
from cutlass.backend.memory_manager import PoolMemoryManager, create_memory_pool
|
||||
from cutlass.backend.operation import *
|
||||
from cutlass.backend.reduction_operation import *
|
||||
from cutlass.backend.type_hint import *
|
||||
from cutlass.backend.utils import *
|
||||
from cutlass.backend.utils.device import device_cc
|
||||
from cutlass_cppgen.backend.arguments import *
|
||||
from cutlass_cppgen.backend.c_types import *
|
||||
from cutlass_cppgen.backend.compiler import ArtifactManager
|
||||
from cutlass_cppgen.backend.conv2d_operation import *
|
||||
from cutlass_cppgen.backend.epilogue import *
|
||||
from cutlass_cppgen.backend.frontend import *
|
||||
from cutlass_cppgen.backend.gemm_operation import *
|
||||
from cutlass_cppgen.backend.library import *
|
||||
from cutlass_cppgen.backend.memory_manager import PoolMemoryManager, create_memory_pool
|
||||
from cutlass_cppgen.backend.operation import *
|
||||
from cutlass_cppgen.backend.reduction_operation import *
|
||||
from cutlass_cppgen.backend.type_hint import *
|
||||
from cutlass_cppgen.backend.utils import *
|
||||
from cutlass_cppgen.backend.utils.device import device_cc
|
||||
|
||||
compiler = ArtifactManager()
|
||||
|
||||
@@ -33,16 +33,16 @@
|
||||
from math import prod
|
||||
from typing import Union
|
||||
|
||||
from cutlass.utils.lazy_import import lazy_import
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
cudart = lazy_import("cuda.cudart")
|
||||
import numpy as np
|
||||
|
||||
import cutlass
|
||||
from cutlass.backend.frontend import CupyFrontend, NumpyFrontend, TorchFrontend
|
||||
from cutlass.backend.memory_manager import DevicePtrWrapper
|
||||
from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen.backend.frontend import CupyFrontend, NumpyFrontend, TorchFrontend
|
||||
from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper
|
||||
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
|
||||
|
||||
|
||||
class ArgumentBase:
|
||||
@@ -122,7 +122,7 @@ class ArgumentBase:
|
||||
Frees allocated device-side memory
|
||||
"""
|
||||
# Free any device memory allocated manually
|
||||
if not cutlass.use_rmm:
|
||||
if not cutlass_cppgen.use_rmm:
|
||||
for name, buf in self.buffers.items():
|
||||
if isinstance(buf, DevicePtrWrapper):
|
||||
err, = cudart.cudaFree(buf.ptr)
|
||||
|
||||
@@ -37,7 +37,7 @@ from cutlass_library import (
|
||||
KernelScheduleType,
|
||||
TileSchedulerType
|
||||
)
|
||||
from cutlass.backend.library import DataTypeSizeBytes
|
||||
from cutlass_cppgen.backend.library import DataTypeSizeBytes
|
||||
|
||||
|
||||
class GemmCoord_(ctypes.Structure):
|
||||
|
||||
@@ -37,17 +37,17 @@ import sqlite3
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
from cutlass.utils.lazy_import import lazy_import
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
cudart = lazy_import("cuda.cudart")
|
||||
nvrtc = lazy_import("cuda.nvrtc")
|
||||
from cutlass_library import SubstituteTemplate
|
||||
|
||||
import cutlass
|
||||
from cutlass import CACHE_FILE, CUTLASS_PATH, cuda_install_path, logger
|
||||
from cutlass.backend.gemm_operation import GemmOperationUniversal
|
||||
from cutlass.backend.library import ApiVersion
|
||||
from cutlass.backend.utils.device import device_cc
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen import CACHE_FILE, CUTLASS_PATH, cuda_install_path, logger
|
||||
from cutlass_cppgen.backend.gemm_operation import GemmOperationUniversal
|
||||
from cutlass_cppgen.backend.library import ApiVersion
|
||||
from cutlass_cppgen.backend.utils.device import device_cc
|
||||
|
||||
IncludeTemplate = r"""#include "${include}"
|
||||
"""
|
||||
@@ -93,7 +93,7 @@ class CompilationOptions:
|
||||
opts.append(f"--include-path={incl}")
|
||||
|
||||
arch_flag = f"-arch=sm_{self.arch}"
|
||||
if self.arch == 90 and int(cutlass.nvcc_version().split('.')[0]) >= 12:
|
||||
if self.arch == 90 and int(cutlass_cppgen.nvcc_version().split('.')[0]) >= 12:
|
||||
arch_flag += "a"
|
||||
opts.append(arch_flag)
|
||||
|
||||
@@ -366,7 +366,7 @@ class ArtifactManager:
|
||||
CUTLASS_PATH + "/python/cutlass/cpp/include",
|
||||
]
|
||||
|
||||
cutlass.initialize_cuda_context()
|
||||
cutlass_cppgen.initialize_cuda_context()
|
||||
arch = device_cc()
|
||||
|
||||
host_compile_options = CompilationOptions(
|
||||
|
||||
@@ -34,7 +34,7 @@ from __future__ import annotations
|
||||
import ctypes
|
||||
from typing import Union
|
||||
|
||||
from cutlass.utils.lazy_import import lazy_import
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
from cutlass_library import SubstituteTemplate
|
||||
import numpy as np
|
||||
@@ -65,17 +65,17 @@ from cutlass_library import (
|
||||
get_complex_from_real,
|
||||
)
|
||||
|
||||
from cutlass.backend.arguments import ArgumentBase
|
||||
from cutlass.backend.c_types import dim3_, get_conv2d_arguments
|
||||
from cutlass.backend.library import (
|
||||
from cutlass_cppgen.backend.arguments import ArgumentBase
|
||||
from cutlass_cppgen.backend.c_types import dim3_, get_conv2d_arguments
|
||||
from cutlass_cppgen.backend.library import (
|
||||
EmissionType,
|
||||
TensorDescription,
|
||||
TileDescription,
|
||||
)
|
||||
from cutlass.backend.memory_manager import device_mem_alloc
|
||||
from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration
|
||||
from cutlass.backend.utils.device import to_device_ptr
|
||||
from cutlass.shape import GemmCoord
|
||||
from cutlass_cppgen.backend.memory_manager import device_mem_alloc
|
||||
from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
|
||||
from cutlass_cppgen.backend.utils.device import to_device_ptr
|
||||
from cutlass_cppgen.shape import GemmCoord
|
||||
|
||||
|
||||
class Conv2dArguments(ArgumentBase):
|
||||
@@ -84,9 +84,9 @@ class Conv2dArguments(ArgumentBase):
|
||||
user-provide tensors into the kernel's argument.
|
||||
|
||||
:param operation: the Conv2d operation to take the argument
|
||||
:type operation: :class:`cutlass.backend.Conv2dOperation`
|
||||
:type operation: :class:`cutlass_cppgen.backend.Conv2dOperation`
|
||||
:param problem_size: the Conv2d problem size
|
||||
:type problem_size: :class:`cutlass.shape.Conv2dProblemSize`
|
||||
:type problem_size: :class:`cutlass_cppgen.shape.Conv2dProblemSize`
|
||||
:param A: tensor A
|
||||
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
||||
:param B: tensor B
|
||||
@@ -98,7 +98,7 @@ class Conv2dArguments(ArgumentBase):
|
||||
:param split_k_mode: conv2d split K mode, defaults to cutlass_library.library.SplitKMode.Serial
|
||||
:type split_k_mode: cutlass_library.library.SplitKMode, optional
|
||||
:param output_op: output operator, optional
|
||||
:type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
|
||||
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
|
||||
:param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
|
||||
:type stream: :class:`cuda.cuda.CUstream`
|
||||
"""
|
||||
@@ -380,19 +380,19 @@ class Conv2dOperation:
|
||||
:type arch: int
|
||||
|
||||
:param tile_description: tile description
|
||||
:type tile_description: :class:`cutlass.backend.TileDescription`
|
||||
:type tile_description: :class:`cutlass_cppgen.backend.TileDescription`
|
||||
|
||||
:param A: tensor A description
|
||||
:type A: :class:`cutlass.backend.TensorDescription`
|
||||
:type A: :class:`cutlass_cppgen.backend.TensorDescription`
|
||||
|
||||
:param B: tensor B description
|
||||
:type B: :class:`cutlass.backend.TensorDescription`
|
||||
:type B: :class:`cutlass_cppgen.backend.TensorDescription`
|
||||
|
||||
:param C: tensor C description
|
||||
:type C: :class:`cutlass.backend.TensorDescription`
|
||||
:type C: :class:`cutlass_cppgen.backend.TensorDescription`
|
||||
|
||||
:param D: tensor D description
|
||||
:type D: :class:`cutlass.backend.TensorDescription`
|
||||
:type D: :class:`cutlass_cppgen.backend.TensorDescription`
|
||||
|
||||
:param element_epilogue: element type for computation in epilogue \
|
||||
:type element_epilogue: cutlass_library.library.DataType
|
||||
@@ -444,7 +444,7 @@ class Conv2dOperation:
|
||||
Launch the cuda kernel with input arguments
|
||||
|
||||
:param arguments: conv2d arguments
|
||||
:type arguments: :class:`cutlass.backend.Conv2dArguments`
|
||||
:type arguments: :class:`cutlass_cppgen.backend.Conv2dArguments`
|
||||
"""
|
||||
|
||||
# launch the kernel
|
||||
|
||||
@@ -36,10 +36,10 @@ from cutlass_library import SubstituteTemplate
|
||||
import numpy as np
|
||||
|
||||
from cutlass_library import DataType, DataTypeTag
|
||||
from cutlass.backend.c_types import MatrixCoord_, tuple_factory
|
||||
from cutlass.backend.frontend import NumpyFrontend
|
||||
from cutlass.backend.library import ActivationOp, ActivationOpTag
|
||||
from cutlass.utils.datatypes import is_numpy_tensor, is_torch_available, is_torch_tensor
|
||||
from cutlass_cppgen.backend.c_types import MatrixCoord_, tuple_factory
|
||||
from cutlass_cppgen.backend.frontend import NumpyFrontend
|
||||
from cutlass_cppgen.backend.library import ActivationOp, ActivationOpTag
|
||||
from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_available, is_torch_tensor
|
||||
|
||||
dtype2ctype = {
|
||||
DataType.f16: ctypes.c_uint16,
|
||||
|
||||
@@ -30,5 +30,5 @@
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass.backend.evt.epilogue import EpilogueFunctorVisitor
|
||||
from cutlass.backend.evt.frontend import PythonASTFrontend
|
||||
from cutlass_cppgen.backend.evt.epilogue import EpilogueFunctorVisitor
|
||||
from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend
|
||||
|
||||
@@ -30,7 +30,7 @@
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass.backend.evt.backend.sm80_emitter import Sm80Emitter
|
||||
import cutlass.backend.evt.backend.sm80_nodes as sm80_nodes
|
||||
from cutlass.backend.evt.backend.sm90_emitter import Sm90Emitter
|
||||
import cutlass.backend.evt.backend.sm90_nodes as sm90_nodes
|
||||
from cutlass_cppgen.backend.evt.backend.sm80_emitter import Sm80Emitter
|
||||
import cutlass_cppgen.backend.evt.backend.sm80_nodes as sm80_nodes
|
||||
from cutlass_cppgen.backend.evt.backend.sm90_emitter import Sm90Emitter
|
||||
import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes
|
||||
|
||||
@@ -35,7 +35,7 @@ Base class for Epilogue Visitor Emitter
|
||||
"""
|
||||
|
||||
from cutlass_library import DataTypeTag
|
||||
from cutlass.backend.evt.ir import TopoVisitorNode, DAGIR
|
||||
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR
|
||||
|
||||
|
||||
class FusionCallbacks:
|
||||
|
||||
@@ -34,8 +34,8 @@
|
||||
Emitter for Sm80 Epilogue Visitor
|
||||
"""
|
||||
|
||||
from cutlass.backend.evt.backend.emitter_base import FusionCallbacks
|
||||
from cutlass.backend import GemmOperationUniversal
|
||||
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
|
||||
from cutlass_cppgen.backend import GemmOperationUniversal
|
||||
|
||||
|
||||
class Sm80Emitter:
|
||||
|
||||
@@ -32,7 +32,7 @@
|
||||
|
||||
from cutlass_library import DataTypeSize, DataTypeTag
|
||||
|
||||
from cutlass.backend.evt.ir import (
|
||||
from cutlass_cppgen.backend.evt.ir import (
|
||||
# Load Node
|
||||
AccumulatorImpl,
|
||||
AuxLoadImpl,
|
||||
@@ -50,7 +50,7 @@ from cutlass.backend.evt.ir import (
|
||||
ScalarReductionImpl
|
||||
)
|
||||
|
||||
from cutlass.backend.library import (
|
||||
from cutlass_cppgen.backend.library import (
|
||||
FloatRoundStyleTag,
|
||||
FunctionalOp,
|
||||
op_tag,
|
||||
|
||||
@@ -35,8 +35,8 @@ Emitter for Sm90 Epilogue Visitor
|
||||
"""
|
||||
|
||||
from cutlass_library import DataTypeTag, EpilogueScheduleTag
|
||||
from cutlass.backend import GemmOperationUniversal
|
||||
from cutlass.backend.evt.backend.emitter_base import FusionCallbacks
|
||||
from cutlass_cppgen.backend import GemmOperationUniversal
|
||||
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
|
||||
|
||||
|
||||
class CollectiveEpilogue:
|
||||
|
||||
@@ -33,7 +33,7 @@
|
||||
from pycute import product
|
||||
|
||||
from cutlass_library import DataTypeSize, DataTypeTag
|
||||
from cutlass.backend.evt.ir import (
|
||||
from cutlass_cppgen.backend.evt.ir import (
|
||||
# Load Node
|
||||
AccumulatorImpl,
|
||||
AuxLoadImpl,
|
||||
@@ -53,7 +53,7 @@ from cutlass.backend.evt.ir import (
|
||||
StoreNode,
|
||||
StoreDImpl,
|
||||
)
|
||||
from cutlass.backend.library import (
|
||||
from cutlass_cppgen.backend.library import (
|
||||
FloatRoundStyleTag,
|
||||
FunctionalOp,
|
||||
op_tag,
|
||||
|
||||
@@ -36,16 +36,16 @@ Epilogue Visitor interface for compiling, and running visitor-based epilogue.
|
||||
|
||||
import ctypes
|
||||
|
||||
from cutlass.utils.lazy_import import lazy_import
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
from cutlass_library import DataType
|
||||
import numpy as np
|
||||
|
||||
from cutlass.backend.epilogue import EpilogueFunctorBase
|
||||
import cutlass.backend.evt.backend
|
||||
from cutlass.backend.frontend import TensorFrontend
|
||||
from cutlass.utils.datatypes import is_numpy_tensor
|
||||
from cutlass.backend.evt.passes.util import cc_map
|
||||
from cutlass_cppgen.backend.epilogue import EpilogueFunctorBase
|
||||
import cutlass_cppgen.backend.evt.backend
|
||||
from cutlass_cppgen.backend.frontend import TensorFrontend
|
||||
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
|
||||
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
||||
|
||||
|
||||
class EpilogueFunctorVisitor(EpilogueFunctorBase):
|
||||
@@ -58,7 +58,7 @@ class EpilogueFunctorVisitor(EpilogueFunctorBase):
|
||||
"""
|
||||
def __init__(self, cc: int, visitor, element_compute=DataType.f32) -> None:
|
||||
# Type of Emitter based on CC
|
||||
self.emit_cls = getattr(cutlass.backend.evt.backend, f"Sm{cc_map[cc]}Emitter")
|
||||
self.emit_cls = getattr(cutlass_cppgen.backend.evt.backend, f"Sm{cc_map[cc]}Emitter")
|
||||
|
||||
# Visitor Types
|
||||
self.visitor = visitor
|
||||
|
||||
@@ -30,4 +30,4 @@
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass.backend.evt.frontend.python_ast import PythonASTFrontend
|
||||
from cutlass_cppgen.backend.evt.frontend.python_ast import PythonASTFrontend
|
||||
|
||||
@@ -37,14 +37,14 @@ Base class for Python EVT Frontend
|
||||
from typing import Union
|
||||
|
||||
from cutlass_library import DataType
|
||||
from cutlass.backend.evt.ir import (
|
||||
from cutlass_cppgen.backend.evt.ir import (
|
||||
ComputeNode,
|
||||
DAGIR,
|
||||
LayoutNode,
|
||||
LoadNode,
|
||||
StoreNode,
|
||||
)
|
||||
from cutlass.backend.evt.passes import (
|
||||
from cutlass_cppgen.backend.evt.passes import (
|
||||
EVTGraphDrawer,
|
||||
EVTPassManager,
|
||||
GetSmemSize,
|
||||
@@ -56,9 +56,9 @@ from cutlass.backend.evt.passes import (
|
||||
PassPreprocessRed,
|
||||
PassShapeTypePropagation,
|
||||
)
|
||||
from cutlass.backend.utils import device_cc
|
||||
from cutlass.epilogue.evt_ops import permute, reshape
|
||||
from cutlass.utils.datatypes import library_type
|
||||
from cutlass_cppgen.backend.utils import device_cc
|
||||
from cutlass_cppgen.epilogue.evt_ops import permute, reshape
|
||||
from cutlass_cppgen.utils.datatypes import library_type
|
||||
|
||||
|
||||
class EVTFrontendBase:
|
||||
|
||||
@@ -40,10 +40,10 @@ import textwrap
|
||||
|
||||
from cutlass_library import DataType
|
||||
|
||||
import cutlass
|
||||
from cutlass.backend.evt.frontend.frontend_base import EVTFrontendBase
|
||||
from cutlass.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu
|
||||
from cutlass.backend.library import FunctionalOp
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen.backend.evt.frontend.frontend_base import EVTFrontendBase
|
||||
from cutlass_cppgen.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu
|
||||
from cutlass_cppgen.backend.library import FunctionalOp
|
||||
|
||||
|
||||
class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
|
||||
|
||||
@@ -30,10 +30,10 @@
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass.backend.evt.ir.compute_nodes import ComputeNode, ComputeImpl
|
||||
from cutlass.backend.evt.ir.dag_ir import DAGIR
|
||||
from cutlass.backend.evt.ir.layout_nodes import LayoutNode
|
||||
from cutlass.backend.evt.ir.load_nodes import (
|
||||
from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode, ComputeImpl
|
||||
from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR
|
||||
from cutlass_cppgen.backend.evt.ir.layout_nodes import LayoutNode
|
||||
from cutlass_cppgen.backend.evt.ir.load_nodes import (
|
||||
LoadNode,
|
||||
AccumulatorImpl,
|
||||
LoadSrcImpl,
|
||||
@@ -42,8 +42,8 @@ from cutlass.backend.evt.ir.load_nodes import (
|
||||
ColumnBroadcastImpl,
|
||||
ScalarBroadcastImpl
|
||||
)
|
||||
from cutlass.backend.evt.ir.node import TopoVisitorNode, NoOpImpl
|
||||
from cutlass.backend.evt.ir.store_nodes import (
|
||||
from cutlass_cppgen.backend.evt.ir.node import TopoVisitorNode, NoOpImpl
|
||||
from cutlass_cppgen.backend.evt.ir.store_nodes import (
|
||||
StoreNode,
|
||||
StoreDImpl,
|
||||
AuxStoreImpl,
|
||||
|
||||
@@ -34,8 +34,8 @@
|
||||
Python registration for compute nodes in EVT
|
||||
"""
|
||||
|
||||
from cutlass.backend.evt.ir.node import NodeBase, ImplBase
|
||||
from cutlass.backend.library import FloatRoundStyle
|
||||
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
|
||||
from cutlass_cppgen.backend.library import FloatRoundStyle
|
||||
|
||||
|
||||
class ComputeImplBase(ImplBase):
|
||||
|
||||
@@ -38,10 +38,10 @@ import networkx as nx
|
||||
|
||||
from cutlass_library import DataType
|
||||
|
||||
from cutlass.backend.evt.ir.compute_nodes import ComputeNode
|
||||
from cutlass.backend.evt.ir.node import NodeBase
|
||||
from cutlass.backend.library import ActivationOp
|
||||
from cutlass.backend.utils import device_cc
|
||||
from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode
|
||||
from cutlass_cppgen.backend.evt.ir.node import NodeBase
|
||||
from cutlass_cppgen.backend.library import ActivationOp
|
||||
from cutlass_cppgen.backend.utils import device_cc
|
||||
|
||||
|
||||
class DAGIR:
|
||||
|
||||
@@ -41,10 +41,10 @@ from copy import deepcopy
|
||||
from cutlass_library import LayoutType
|
||||
from pycute import product, flatten
|
||||
|
||||
import cutlass
|
||||
from cutlass.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list
|
||||
from cutlass.backend.evt.ir.node import NodeBase
|
||||
from cutlass.backend.evt.ir.tensor import Tensor
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list
|
||||
from cutlass_cppgen.backend.evt.ir.node import NodeBase
|
||||
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
|
||||
|
||||
|
||||
class PermutationImpl:
|
||||
|
||||
@@ -36,9 +36,9 @@ Load nodes and implementations
|
||||
|
||||
import ctypes
|
||||
|
||||
from cutlass.backend.c_types import tuple_factory
|
||||
from cutlass.backend.epilogue import dtype2ctype, to_ctype_value
|
||||
from cutlass.backend.evt.ir.node import NodeBase, ImplBase
|
||||
from cutlass_cppgen.backend.c_types import tuple_factory
|
||||
from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
|
||||
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
|
||||
|
||||
|
||||
class LoadImplBase(ImplBase):
|
||||
|
||||
@@ -39,8 +39,8 @@ from re import sub
|
||||
|
||||
from cutlass_library import LayoutType
|
||||
|
||||
from cutlass.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple
|
||||
from cutlass.backend.evt.ir.tensor import Tensor
|
||||
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple
|
||||
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
|
||||
|
||||
|
||||
class ImplBase:
|
||||
@@ -170,7 +170,7 @@ class NodeBase:
|
||||
@property
|
||||
def tensor(self) -> Tensor:
|
||||
"""
|
||||
Return the output tensor (concept: cutlass.backend.evt.ir.tensor)
|
||||
Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
|
||||
"""
|
||||
return self._tensor
|
||||
|
||||
|
||||
@@ -38,11 +38,11 @@ import ctypes
|
||||
|
||||
from cutlass_library import DataType
|
||||
|
||||
from cutlass.backend.c_types import tuple_factory
|
||||
from cutlass.backend.epilogue import dtype2ctype, to_ctype_value
|
||||
from cutlass.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl
|
||||
from cutlass.backend.evt.ir.tensor import Tensor
|
||||
from cutlass.backend.library import FloatRoundStyle, FunctionalOp
|
||||
from cutlass_cppgen.backend.c_types import tuple_factory
|
||||
from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
|
||||
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl
|
||||
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
|
||||
from cutlass_cppgen.backend.library import FloatRoundStyle, FunctionalOp
|
||||
|
||||
|
||||
class StoreImplBase(ImplBase):
|
||||
@@ -249,7 +249,7 @@ class StoreNode(NodeBase):
|
||||
@property
|
||||
def store_tensor(self) -> Tensor:
|
||||
"""
|
||||
Return the output tensor (concept: cutlass.backend.evt.ir.tensor)
|
||||
Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
|
||||
"""
|
||||
return self._store_tensor
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ High-level class for tensor
|
||||
|
||||
from cutlass_library import LayoutType
|
||||
|
||||
from cutlass.backend.evt.ir.layout_algorithm import (
|
||||
from cutlass_cppgen.backend.evt.ir.layout_algorithm import (
|
||||
Layout,
|
||||
broadcast,
|
||||
canonicalization,
|
||||
@@ -44,7 +44,7 @@ from cutlass.backend.evt.ir.layout_algorithm import (
|
||||
reshape,
|
||||
_reverse_tuple
|
||||
)
|
||||
from cutlass.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type
|
||||
from cutlass_cppgen.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type
|
||||
|
||||
|
||||
class Tensor:
|
||||
|
||||
@@ -30,13 +30,13 @@
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass.backend.evt.passes.graph_drawer import EVTGraphDrawer
|
||||
from cutlass.backend.evt.passes.pass_argument_type import PassGetArgumentType
|
||||
from cutlass.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
|
||||
from cutlass.backend.evt.passes.pass_get_impl import PassGetImpl
|
||||
from cutlass.backend.evt.passes.pass_fix_element_d import PassFixElementD
|
||||
from cutlass.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
|
||||
from cutlass.backend.evt.passes.pass_manager import EVTPassManager
|
||||
from cutlass.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
|
||||
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
||||
from cutlass.backend.evt.passes.smem_size_calculator import GetSmemSize
|
||||
from cutlass_cppgen.backend.evt.passes.graph_drawer import EVTGraphDrawer
|
||||
from cutlass_cppgen.backend.evt.passes.pass_argument_type import PassGetArgumentType
|
||||
from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
|
||||
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
|
||||
from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
|
||||
from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassManager
|
||||
from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
|
||||
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
||||
from cutlass_cppgen.backend.evt.passes.smem_size_calculator import GetSmemSize
|
||||
|
||||
@@ -35,7 +35,7 @@ import subprocess
|
||||
|
||||
from cutlass_library import DataTypeTag
|
||||
|
||||
from cutlass.backend.evt.ir.dag_ir import DAGIR
|
||||
from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR
|
||||
|
||||
|
||||
_COLOR_MAP = {
|
||||
|
||||
@@ -34,12 +34,12 @@
|
||||
Construct the epilogue visitor argument type
|
||||
"""
|
||||
|
||||
from cutlass.backend.c_types import visitor_factory
|
||||
from cutlass.backend.evt.ir import TopoVisitorNode
|
||||
from cutlass.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
|
||||
from cutlass.backend.evt.passes.pass_get_impl import PassGetImpl
|
||||
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
||||
from cutlass_cppgen.backend.c_types import visitor_factory
|
||||
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode
|
||||
from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
|
||||
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
||||
|
||||
|
||||
class PassGetArgumentType(EVTPassBase):
|
||||
|
||||
@@ -37,10 +37,10 @@ by the topological visitor, while the rest of the graph will be implemented with
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from cutlass.backend.evt.ir import DAGIR, TopoVisitorNode
|
||||
from cutlass.backend.evt.passes.pass_get_impl import PassGetImpl
|
||||
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
||||
from cutlass_cppgen.backend.evt.ir import DAGIR, TopoVisitorNode
|
||||
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
||||
|
||||
|
||||
class PassDAG2Tree(EVTPassBase):
|
||||
|
||||
@@ -37,8 +37,8 @@ In Sm90 epilogue visitor, the node writing D to gmem does not have internal
|
||||
element converter, so the compute node producing D must have element_output = type(D).
|
||||
"""
|
||||
|
||||
from cutlass.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
|
||||
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
||||
|
||||
|
||||
class PassFixElementD(EVTPassBase):
|
||||
|
||||
@@ -39,13 +39,13 @@ on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadca
|
||||
This pass infers the underlying impl of each node
|
||||
"""
|
||||
|
||||
import cutlass.backend.evt.backend as evt_backend
|
||||
from cutlass.backend.evt.ir import DAGIR, LoadNode
|
||||
from cutlass.backend.evt.passes.pass_fix_element_d import PassFixElementD
|
||||
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination
|
||||
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
||||
from cutlass.backend.evt.passes.util import cc_map
|
||||
import cutlass_cppgen.backend.evt.backend as evt_backend
|
||||
from cutlass_cppgen.backend.evt.ir import DAGIR, LoadNode
|
||||
from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass_cppgen.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination
|
||||
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
||||
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
||||
|
||||
|
||||
class PassGetImpl(EVTPassBase):
|
||||
|
||||
@@ -36,9 +36,9 @@ Eliminate layout manipulation nodes
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from cutlass.backend.evt.ir import DAGIR, LayoutNode
|
||||
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
||||
from cutlass_cppgen.backend.evt.ir import DAGIR, LayoutNode
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
||||
|
||||
|
||||
class PassLayoutManipulateElimination(EVTPassBase):
|
||||
|
||||
@@ -38,8 +38,8 @@ from typing import Any
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from cutlass.backend.evt.ir import DAGIR
|
||||
from cutlass.backend.evt.passes.util import cc_map
|
||||
from cutlass_cppgen.backend.evt.ir import DAGIR
|
||||
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
||||
|
||||
|
||||
class EVTPassBase:
|
||||
|
||||
@@ -36,8 +36,8 @@ No op elimination node
|
||||
|
||||
from typing import Any
|
||||
|
||||
from cutlass.backend.evt.ir import NoOpImpl
|
||||
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass_cppgen.backend.evt.ir import NoOpImpl
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
||||
|
||||
|
||||
class PassNoOpElimination(EVTPassBase):
|
||||
|
||||
@@ -38,8 +38,8 @@ This pass fuses these into a single store node, and then replaces all uses of th
|
||||
current node with the new store node.
|
||||
"""
|
||||
|
||||
from cutlass.backend.evt.ir import ComputeNode, StoreNode
|
||||
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass_cppgen.backend.evt.ir import ComputeNode, StoreNode
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
||||
|
||||
|
||||
class PassPreprocessRed(EVTPassBase):
|
||||
|
||||
@@ -34,9 +34,9 @@
|
||||
Shape and type propagation pass
|
||||
"""
|
||||
|
||||
from cutlass.backend.evt.ir.node import NodeBase
|
||||
from cutlass.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
|
||||
from cutlass_cppgen.backend.evt.ir.node import NodeBase
|
||||
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
||||
from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
|
||||
|
||||
|
||||
class PassShapeTypePropagation(EVTPassBase):
|
||||
|
||||
@@ -37,9 +37,9 @@ Compute the shared memory size in bytes
|
||||
import cutlass_library
|
||||
from pycute import shape_div, product
|
||||
|
||||
import cutlass
|
||||
from cutlass.backend.evt.ir import TopoVisitorNode, DAGIR
|
||||
from cutlass.backend.library import DataTypeSize
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR
|
||||
from cutlass_cppgen.backend.library import DataTypeSize
|
||||
|
||||
|
||||
class GetSmemSize:
|
||||
|
||||
@@ -31,12 +31,12 @@
|
||||
#################################################################################################
|
||||
from __future__ import annotations
|
||||
|
||||
from cutlass.utils.lazy_import import lazy_import
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
import numpy as np
|
||||
|
||||
from cutlass.backend.memory_manager import device_mem_alloc, todevice
|
||||
from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
|
||||
from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice
|
||||
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
|
||||
|
||||
|
||||
class NumpyFrontend:
|
||||
|
||||
@@ -35,7 +35,7 @@ import copy
|
||||
import ctypes
|
||||
import enum
|
||||
|
||||
from cutlass.utils.lazy_import import lazy_import
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
cudart = lazy_import("cuda.cudart")
|
||||
from cutlass_library import SubstituteTemplate
|
||||
@@ -74,8 +74,8 @@ from cutlass_library import (
|
||||
TileSchedulerType,
|
||||
get_complex_from_real
|
||||
)
|
||||
from cutlass.backend.arguments import ArgumentBase
|
||||
from cutlass.backend.c_types import (
|
||||
from cutlass_cppgen.backend.arguments import ArgumentBase
|
||||
from cutlass_cppgen.backend.c_types import (
|
||||
GemmCoord_,
|
||||
GemmCoordBatched_,
|
||||
GenericMainloopArguments3x_,
|
||||
@@ -88,7 +88,7 @@ from cutlass.backend.c_types import (
|
||||
get_mainloop_arguments_3x,
|
||||
get_tile_scheduler_arguments_3x,
|
||||
)
|
||||
from cutlass.backend.library import (
|
||||
from cutlass_cppgen.backend.library import (
|
||||
ApiVersion,
|
||||
EmissionType,
|
||||
SchedulerMode,
|
||||
@@ -97,11 +97,11 @@ from cutlass.backend.library import (
|
||||
TileDescription,
|
||||
api_version,
|
||||
)
|
||||
from cutlass.backend.memory_manager import device_mem_alloc, todevice
|
||||
from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration
|
||||
from cutlass.backend.type_hint import GemmOperation, Tensor
|
||||
from cutlass.backend.utils.device import device_sm_count
|
||||
from cutlass.shape import GemmCoord, MatrixCoord
|
||||
from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice
|
||||
from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
|
||||
from cutlass_cppgen.backend.type_hint import GemmOperation, Tensor
|
||||
from cutlass_cppgen.backend.utils.device import device_sm_count
|
||||
from cutlass_cppgen.shape import GemmCoord, MatrixCoord
|
||||
|
||||
|
||||
################################################################################
|
||||
@@ -116,9 +116,9 @@ def leading_dimension(layout: LayoutType, shape: MatrixCoord) -> int:
|
||||
Returns the leading dimenson of a tensor with layout ``layout`` and shape ``shape``.
|
||||
|
||||
:param layout: layout of the tensor
|
||||
:type layout: cutlass.shape.LayoutType
|
||||
:type layout: cutlass_cppgen.shape.LayoutType
|
||||
:param shape: shape of the tensor
|
||||
:type shape: cutlass.shape.MatrixCoord
|
||||
:type shape: cutlass_cppgen.shape.MatrixCoord
|
||||
|
||||
:return: leading dimension of the tensor
|
||||
:rtype: int
|
||||
@@ -144,11 +144,11 @@ class GemmArguments2x(ArgumentBase):
|
||||
user-provide tensors into the kernel's argument
|
||||
|
||||
:param operation: the GEMM operation to take the argument
|
||||
:type operation: :class:`cutlass.backend.GemmOperationUniversal` |
|
||||
:class:`cutlass.backend.GemmOperationGrouped`
|
||||
:type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
|
||||
:class:`cutlass_cppgen.backend.GemmOperationGrouped`
|
||||
|
||||
:param problem_size: GEMM problem size gemm(M, N, K)
|
||||
:type operation: :class:`cutlass.shape.GemmCoord`
|
||||
:type operation: :class:`cutlass_cppgen.shape.GemmCoord`
|
||||
|
||||
:param A: tensor A
|
||||
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
||||
@@ -166,7 +166,7 @@ class GemmArguments2x(ArgumentBase):
|
||||
:type gemm_mode: :class:`cutlass_library.GemmUniversalMode`
|
||||
|
||||
:param output_op: output operator, optional
|
||||
:type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
|
||||
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
|
||||
|
||||
:param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
|
||||
:type stream: :class:`cuda.cuda.CUstream`
|
||||
@@ -371,11 +371,11 @@ class GemmArguments2xStreamK(GemmArguments2x):
|
||||
user-provide tensors into the kernel's argument
|
||||
|
||||
:param operation: the GEMM operation to take the argument
|
||||
:type operation: :class:`cutlass.backend.GemmOperationUniversal` |
|
||||
:class:`cutlass.backend.GemmOperationGrouped`
|
||||
:type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
|
||||
:class:`cutlass_cppgen.backend.GemmOperationGrouped`
|
||||
|
||||
:param problem_size: GEMM problem size gemm(M, N, K)
|
||||
:type operation: :class:`cutlass.shape.GemmCoord`
|
||||
:type operation: :class:`cutlass_cppgen.shape.GemmCoord`
|
||||
|
||||
:param A: tensor A
|
||||
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
||||
@@ -393,7 +393,7 @@ class GemmArguments2xStreamK(GemmArguments2x):
|
||||
:type gemm_mode: :class:`cutlass_library.GemmUniversalMode`
|
||||
|
||||
:param output_op: output operator, optional
|
||||
:type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
|
||||
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
|
||||
"""
|
||||
|
||||
def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
|
||||
@@ -483,11 +483,11 @@ class GemmArguments3x(GemmArguments2x):
|
||||
user-provide tensors into the kernel's argument
|
||||
|
||||
:param operation: the GEMM operation to take the argument
|
||||
:type operation: :class:`cutlass.backend.GemmOperationUniversal` |
|
||||
:class:`cutlass.backend.GemmOperationGrouped`
|
||||
:type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
|
||||
:class:`cutlass_cppgen.backend.GemmOperationGrouped`
|
||||
|
||||
:param problem_size: GEMM problem size gemm(M, N, K)
|
||||
:type operation: :class:`cutlass.shape.GemmCoord`
|
||||
:type operation: :class:`cutlass_cppgen.shape.GemmCoord`
|
||||
|
||||
:param A: tensor A
|
||||
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
||||
@@ -505,7 +505,7 @@ class GemmArguments3x(GemmArguments2x):
|
||||
:type gemm_mode: GemmUniversalMode
|
||||
|
||||
:param output_op: output operator, optional
|
||||
:type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
|
||||
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
|
||||
"""
|
||||
|
||||
def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
|
||||
@@ -631,11 +631,11 @@ def GemmArguments(operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMo
|
||||
or 3x arguments depending on the `arch` field specified in `operation`.
|
||||
|
||||
:param operation: the GEMM operation to take the argument
|
||||
:type operation: :class:`cutlass.backend.GemmOperationUniversal` |
|
||||
:class:`cutlass.backend.GemmOperationGrouped`
|
||||
:type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
|
||||
:class:`cutlass_cppgen.backend.GemmOperationGrouped`
|
||||
|
||||
:param problem_size: GEMM problem size gemm(M, N, K)
|
||||
:type operation: :class:`cutlass.shape.GemmCoord`
|
||||
:type operation: :class:`cutlass_cppgen.shape.GemmCoord`
|
||||
|
||||
:param A: tensor A
|
||||
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
||||
@@ -653,7 +653,7 @@ def GemmArguments(operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMo
|
||||
:type gemm_mode: :class:`cutlass_library.GemmUniversalMode`
|
||||
|
||||
:param output_op: output operator, optional
|
||||
:type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
|
||||
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
|
||||
"""
|
||||
if operation.swizzling_functor == SwizzlingFunctor.StreamK:
|
||||
if operation.api == ApiVersion.v3x:
|
||||
@@ -670,10 +670,10 @@ class GemmGroupedArguments:
|
||||
user-provide tensors into the kernel's argument
|
||||
|
||||
:param operation: the GEMM Grouped operation to take the argument
|
||||
:type operation: :class:`cutlass.backend.GemmOperationGrouped`
|
||||
:type operation: :class:`cutlass_cppgen.backend.GemmOperationGrouped`
|
||||
|
||||
:param problem_size: list of GEMM problem size gemm(M, N, K)
|
||||
:type operation: list[:class:`cutlass.shape.GemmCoord`]
|
||||
:type operation: list[:class:`cutlass_cppgen.shape.GemmCoord`]
|
||||
|
||||
:param A: list of tensor A
|
||||
:type A: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
|
||||
@@ -688,7 +688,7 @@ class GemmGroupedArguments:
|
||||
:type D: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
|
||||
|
||||
:param output_op: output operator, optional
|
||||
:type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
|
||||
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
|
||||
|
||||
:param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
|
||||
:type stream: :class:`cuda.cuda.CUstream`
|
||||
|
||||
@@ -417,7 +417,7 @@ def CalculateSmemUsagePerStage(operation):
|
||||
:param op: operation for which the maximum stages should be computed. If stages are
|
||||
set via the `op.tile_description.stages` parameter, this setting is ignored
|
||||
in the present calculation
|
||||
:type op: cutlass.backend.Operation
|
||||
:type op: cutlass_cppgen.backend.Operation
|
||||
|
||||
:return: number of bytes of shared memory consumed by a single stage
|
||||
:rtype: int
|
||||
@@ -442,7 +442,7 @@ def CalculateSmemUsage(operation):
|
||||
:param op: operation for which the maximum stages should be computed. If stages are
|
||||
set via the `op.tile_description.stages` parameter, this setting is ignored
|
||||
in the present calculation
|
||||
:type op: cutlass.backend.Operation
|
||||
:type op: cutlass_cppgen.backend.Operation
|
||||
|
||||
:return: int
|
||||
"""
|
||||
|
||||
@@ -32,11 +32,11 @@
|
||||
|
||||
import numpy as np
|
||||
|
||||
import cutlass
|
||||
from cutlass.utils.datatypes import is_numpy_tensor
|
||||
from cutlass.utils.lazy_import import lazy_import
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
|
||||
if cutlass.use_rmm:
|
||||
if cutlass_cppgen.use_rmm:
|
||||
import rmm
|
||||
else:
|
||||
cudart = lazy_import("cuda.cudart")
|
||||
@@ -73,7 +73,7 @@ def _todevice(host_data):
|
||||
"""
|
||||
Helper for transferring host data to device memory
|
||||
"""
|
||||
if cutlass.use_rmm:
|
||||
if cutlass_cppgen.use_rmm:
|
||||
return rmm.DeviceBuffer.to_device(host_data.tobytes())
|
||||
else:
|
||||
nbytes = len(host_data.tobytes())
|
||||
@@ -100,7 +100,7 @@ def todevice(host_data, dtype=np.float32):
|
||||
|
||||
|
||||
def device_mem_alloc(size):
|
||||
if cutlass.use_rmm:
|
||||
if cutlass_cppgen.use_rmm:
|
||||
return rmm.DeviceBuffer(size=size)
|
||||
else:
|
||||
err, ptr = cudart.cudaMalloc(size)
|
||||
@@ -114,7 +114,7 @@ def align_size(size, alignment=256):
|
||||
|
||||
|
||||
def create_memory_pool(init_pool_size=0, max_pool_size=2 ** 34):
|
||||
if cutlass.use_rmm:
|
||||
if cutlass_cppgen.use_rmm:
|
||||
memory_pool = PoolMemoryManager(init_pool_size=init_pool_size, max_pool_size=max_pool_size)
|
||||
return memory_pool
|
||||
else:
|
||||
|
||||
@@ -31,10 +31,10 @@
|
||||
#################################################################################################
|
||||
|
||||
import ctypes
|
||||
from cutlass.utils.lazy_import import lazy_import
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
|
||||
from cutlass.backend.utils.device import device_cc
|
||||
from cutlass_cppgen.backend.utils.device import device_cc
|
||||
|
||||
_supports_cluster_launch = None
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ from __future__ import annotations
|
||||
import ctypes
|
||||
from typing import Union
|
||||
|
||||
from cutlass.utils.lazy_import import lazy_import
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
cudart = lazy_import("cuda.cudart")
|
||||
import numpy as np
|
||||
@@ -47,14 +47,14 @@ from cutlass_library import (
|
||||
SubstituteTemplate
|
||||
)
|
||||
|
||||
import cutlass
|
||||
from cutlass.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params
|
||||
from cutlass.backend.frontend import NumpyFrontend, TorchFrontend
|
||||
from cutlass.backend.library import TensorDescription
|
||||
from cutlass.backend.memory_manager import DevicePtrWrapper
|
||||
from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration
|
||||
from cutlass.shape import MatrixCoord
|
||||
from cutlass.utils.datatypes import is_numpy_tensor, is_torch_tensor
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params
|
||||
from cutlass_cppgen.backend.frontend import NumpyFrontend, TorchFrontend
|
||||
from cutlass_cppgen.backend.library import TensorDescription
|
||||
from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper
|
||||
from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
|
||||
from cutlass_cppgen.shape import MatrixCoord
|
||||
from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_tensor
|
||||
|
||||
|
||||
class ReductionOperation:
|
||||
@@ -200,7 +200,7 @@ class ReductionArguments:
|
||||
Frees allocated device-side memory
|
||||
"""
|
||||
# Free any device memory allocated manually
|
||||
if not cutlass.use_rmm:
|
||||
if not cutlass_cppgen.use_rmm:
|
||||
for attr in ["destination_buffer", "source_buffer"]:
|
||||
if hasattr(self, attr):
|
||||
buf = getattr(self, attr)
|
||||
|
||||
@@ -30,4 +30,4 @@
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from cutlass.backend.utils.device import check_cuda_errors, device_cc
|
||||
from cutlass_cppgen.backend.utils.device import check_cuda_errors, device_cc
|
||||
|
||||
@@ -35,12 +35,12 @@ Utility functions for interacting with the device
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from cutlass.utils.lazy_import import lazy_import
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
cudart = lazy_import("cuda.cudart")
|
||||
|
||||
import cutlass
|
||||
from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
|
||||
|
||||
|
||||
def check_cuda_errors(result: list):
|
||||
@@ -77,7 +77,7 @@ def device_cc(device: int = -1) -> int:
|
||||
:rtype: int
|
||||
"""
|
||||
if device == -1:
|
||||
device = cutlass.device_id()
|
||||
device = cutlass_cppgen.device_id()
|
||||
|
||||
deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device))
|
||||
major = str(deviceProp.major)
|
||||
@@ -87,7 +87,7 @@ def device_cc(device: int = -1) -> int:
|
||||
|
||||
def device_sm_count(device: int = -1):
|
||||
if device == -1:
|
||||
device = cutlass.device_id()
|
||||
device = cutlass_cppgen.device_id()
|
||||
err, device_sm_count = cuda.cuDeviceGetAttribute(
|
||||
cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device
|
||||
)
|
||||
|
||||
@@ -30,4 +30,4 @@
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass.emit.pytorch import pytorch
|
||||
from cutlass_cppgen.emit.pytorch import pytorch
|
||||
|
||||
@@ -34,10 +34,10 @@
|
||||
Common utilities for emitting CUTLASS kernels
|
||||
"""
|
||||
|
||||
import cutlass
|
||||
import cutlass_cppgen
|
||||
|
||||
# Strings used for printing information about the generation of emitted scripts
|
||||
_AUTOGEN_STR = f"This file was automatically generated by the CUTLASS {cutlass.__version__} Python interface (https://github.com/nvidia/cutlass/python)"
|
||||
_AUTOGEN_STR = f"This file was automatically generated by the CUTLASS {cutlass_cppgen.__version__} Python interface (https://github.com/nvidia/cutlass/python)"
|
||||
|
||||
|
||||
_CSTYLE_AUTOGEN_COMMENT = f"""// {_AUTOGEN_STR}
|
||||
|
||||
@@ -39,9 +39,9 @@ Example usage with JIT compilation:
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass_library.LayoutType.RowMajor)
|
||||
plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_library.LayoutType.RowMajor)
|
||||
op = plan.construct()
|
||||
mod = cutlass.emit.pytorch(op, 'cutlass_gemm', 80, jit=True)
|
||||
mod = cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=True)
|
||||
|
||||
# Generate inputs for the GEMM
|
||||
A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]
|
||||
@@ -55,9 +55,9 @@ Example usage without JIT compilation:
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)
|
||||
plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
op = plan.construct()
|
||||
cutlass.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output')
|
||||
cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output')
|
||||
|
||||
After this call, the directory ``output`` contains ``setup.py``,
|
||||
``cutlass_gemm.cpp``, and ``cutlass_gemm_kernel.cu``. The module can be built from
|
||||
@@ -83,12 +83,12 @@ import os
|
||||
|
||||
from cutlass_library import ConvKind, ConvKindNames, DataType, SubstituteTemplate
|
||||
|
||||
from cutlass import CUTLASS_PATH, logger, swizzle
|
||||
from cutlass.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal
|
||||
from cutlass.backend.conv2d_operation import Conv2dOperation
|
||||
from cutlass.backend.library import ApiVersion
|
||||
from cutlass.emit import common
|
||||
from cutlass.utils.datatypes import is_torch_available
|
||||
from cutlass_cppgen import CUTLASS_PATH, logger, swizzle
|
||||
from cutlass_cppgen.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal
|
||||
from cutlass_cppgen.backend.conv2d_operation import Conv2dOperation
|
||||
from cutlass_cppgen.backend.library import ApiVersion
|
||||
from cutlass_cppgen.emit import common
|
||||
from cutlass_cppgen.utils.datatypes import is_torch_available
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
@@ -30,7 +30,7 @@
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass.epilogue.epilogue import (
|
||||
from cutlass_cppgen.epilogue.epilogue import (
|
||||
get_activations,
|
||||
get_activation_epilogue,
|
||||
gelu,
|
||||
@@ -44,7 +44,7 @@ from cutlass.epilogue.epilogue import (
|
||||
trace
|
||||
)
|
||||
|
||||
from cutlass.epilogue.evt_ops import (
|
||||
from cutlass_cppgen.epilogue.evt_ops import (
|
||||
max,
|
||||
multiply_add,
|
||||
sum,
|
||||
|
||||
@@ -39,11 +39,11 @@ code like the following for GEMM:
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass.op.Gemm(element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)
|
||||
plan.activation = cutlass.epilogue.relu
|
||||
plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
plan.activation = cutlass_cppgen.epilogue.relu
|
||||
"""
|
||||
|
||||
from cutlass.backend import epilogue, device_cc
|
||||
from cutlass_cppgen.backend import epilogue, device_cc
|
||||
|
||||
|
||||
gelu = epilogue.gelu
|
||||
@@ -111,7 +111,7 @@ def get_activation_epilogue(
|
||||
"""
|
||||
Frontend for EVT that generates epilogue functor through tracing the input function
|
||||
"""
|
||||
from cutlass.backend.evt.frontend import PythonASTFrontend
|
||||
from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend
|
||||
|
||||
|
||||
def trace(fn, example_tensors, **kwargs):
|
||||
@@ -124,7 +124,7 @@ def trace(fn, example_tensors, **kwargs):
|
||||
|
||||
.. hightlight:: python
|
||||
.. code-block:: python
|
||||
import cutlass.backend.evt
|
||||
import cutlass_cppgen.backend.evt
|
||||
|
||||
# Define epilogue function as Python callable
|
||||
def example_fn(accum, C, alpha, beta, gamma):
|
||||
@@ -142,7 +142,7 @@ def trace(fn, example_tensors, **kwargs):
|
||||
}
|
||||
|
||||
# Generate the epilogue functor
|
||||
epilogue_visitor = cutlass.epilogue.trace(example_fn, example_inputs)
|
||||
epilogue_visitor = cutlass_cppgen.epilogue.trace(example_fn, example_inputs)
|
||||
"""
|
||||
if callable(fn):
|
||||
class EpilogueFunctor(PythonASTFrontend):
|
||||
|
||||
@@ -36,7 +36,7 @@ Collection of builtin functions used for host reference in EVT
|
||||
|
||||
import numpy as np
|
||||
|
||||
from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_available, is_torch_tensor
|
||||
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_available, is_torch_tensor
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
@@ -40,9 +40,9 @@ import logging
|
||||
import cutlass_library
|
||||
from cutlass_library.library import ConvKind, IteratorAlgorithm, StrideSupport, GroupMode
|
||||
|
||||
import cutlass
|
||||
from cutlass.utils.check import valid_stage_count
|
||||
from cutlass.utils.datatypes import td_from_profiler_td, td_from_profiler_op
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen.utils.check import valid_stage_count
|
||||
from cutlass_cppgen.utils.datatypes import td_from_profiler_td, td_from_profiler_op
|
||||
|
||||
|
||||
_generator_ccs = [50, 60, 61, 70, 75, 80, 90]
|
||||
@@ -99,14 +99,14 @@ class KernelsForDataType:
|
||||
ops.extend(alignment_ops)
|
||||
return ops
|
||||
|
||||
def default_operation(self, math_operation: cutlass.MathOperation):
|
||||
def default_operation(self, math_operation: cutlass_cppgen.MathOperation):
|
||||
key = sorted(list(self.kernels_by_alignment.keys()))[0]
|
||||
kernels = self.kernels_by_alignment[key]
|
||||
if math_operation is not None:
|
||||
kernels = [x for x in kernels if x.tile_description.math_instruction.math_operation == math_operation]
|
||||
return kernels[0]
|
||||
|
||||
def operations(self, alignment_A: int, alignment_B: int, alignment_C: int, math_operation: cutlass.MathOperation):
|
||||
def operations(self, alignment_A: int, alignment_B: int, alignment_C: int, math_operation: cutlass_cppgen.MathOperation):
|
||||
"""
|
||||
Returns operations satisfying the alignment constraints
|
||||
|
||||
@@ -117,7 +117,7 @@ class KernelsForDataType:
|
||||
:param alignment_C: alignment constraint of operations to return
|
||||
:type alignment_C: int
|
||||
:param math_operation: math operation to consider
|
||||
:type math_operation: cutlass.MathOperation
|
||||
:type math_operation: cutlass_cppgen.MathOperation
|
||||
|
||||
:return: list of operations
|
||||
:rtype: list
|
||||
@@ -158,14 +158,14 @@ class KernelsForDataType:
|
||||
|
||||
return operand_list.index(key)
|
||||
|
||||
def find_alignment(self, shape: tuple, layout: cutlass.LayoutType, operand=str) -> int:
|
||||
def find_alignment(self, shape: tuple, layout: cutlass_cppgen.LayoutType, operand=str) -> int:
|
||||
"""
|
||||
Returns the most preferable alignment for a given shape and layout
|
||||
|
||||
:param shape: extent of each dimension of the tensor
|
||||
:type shape: tuple
|
||||
:param layout: layout of the tensor
|
||||
:type layout: cutlass.LayoutType
|
||||
:type layout: cutlass_cppgen.LayoutType
|
||||
:param operand: descriptor of the operand in question
|
||||
:type operand: str
|
||||
|
||||
@@ -175,11 +175,11 @@ class KernelsForDataType:
|
||||
operand_idx = self._operand_idx(operand)
|
||||
|
||||
# Determine the leading dimension of the shape
|
||||
if layout == cutlass.LayoutType.ColumnMajor:
|
||||
if layout == cutlass_cppgen.LayoutType.ColumnMajor:
|
||||
ld = shape[-2]
|
||||
elif layout == cutlass.LayoutType.RowMajor:
|
||||
elif layout == cutlass_cppgen.LayoutType.RowMajor:
|
||||
ld = shape[-1]
|
||||
elif layout == cutlass.LayoutType.TensorNHWC:
|
||||
elif layout == cutlass_cppgen.LayoutType.TensorNHWC:
|
||||
ld = shape[-1]
|
||||
else:
|
||||
raise Exception(f"Unexpected or unsupported layout {layout}")
|
||||
@@ -204,12 +204,12 @@ class KernelsForDataType:
|
||||
for alignment in self.kernels_by_alignment.keys():
|
||||
self.kernels_by_alignment[alignment].sort(key=key, reverse=True)
|
||||
|
||||
def supports_math_operation(self, math_operation: cutlass.MathOperation) -> bool:
|
||||
def supports_math_operation(self, math_operation: cutlass_cppgen.MathOperation) -> bool:
|
||||
"""
|
||||
Returns whether `math_operation` is supported by at least one operation.
|
||||
|
||||
:param math_operation: math operation to consider
|
||||
:type math_operation: cutlass.MathOperation
|
||||
:type math_operation: cutlass_cppgen.MathOperation
|
||||
|
||||
:return: whether math_operation is supported by at least one operation
|
||||
:rtype: bool
|
||||
@@ -262,7 +262,7 @@ class ArchOptions:
|
||||
# descriptions for the target CC
|
||||
generate_function_name = "GenerateSM" + str(kernel_cc)
|
||||
if not hasattr(cutlass_library.generator, generate_function_name):
|
||||
cutlass.logger.warning(f"No generator found for architecture {kernel_cc}")
|
||||
cutlass_cppgen.logger.warning(f"No generator found for architecture {kernel_cc}")
|
||||
return
|
||||
generate_function = getattr(cutlass_library.generator, generate_function_name)
|
||||
|
||||
@@ -270,16 +270,16 @@ class ArchOptions:
|
||||
# for the target CC
|
||||
args = [
|
||||
"--kernels=all",
|
||||
f"--log-level={logging.getLevelName(cutlass.logger.level)}"
|
||||
f"--log-level={logging.getLevelName(cutlass_cppgen.logger.level)}"
|
||||
]
|
||||
manifest_args = cutlass_library.generator.define_parser().parse_args(args)
|
||||
manifest = cutlass_library.manifest.Manifest(manifest_args)
|
||||
generate_function(manifest, cutlass._nvcc_version)
|
||||
generate_function(manifest, cutlass_cppgen._nvcc_version)
|
||||
|
||||
if operation_kind not in manifest.operations:
|
||||
# No kernels generated for this architecture, this could be because the CUDA
|
||||
# toolkit is insufficient to support operations in this CC
|
||||
cutlass.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}")
|
||||
cutlass_cppgen.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}")
|
||||
return
|
||||
|
||||
# Only one CC should be returned, given the setup above of calling only the generation scripts
|
||||
@@ -358,7 +358,7 @@ class ArchOptions:
|
||||
|
||||
# Add FP8 A/B with FP32 C
|
||||
for type_comb in combinations_with_replacement(fp8_types, 2):
|
||||
types.append(type_comb + (cutlass.DataType.f32,))
|
||||
types.append(type_comb + (cutlass_cppgen.DataType.f32,))
|
||||
|
||||
layouts = [
|
||||
(cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.RowMajor),
|
||||
@@ -444,7 +444,7 @@ class ArchOptions:
|
||||
:param layout_comb: tuple of data types for (layout_A, layout_B)
|
||||
:type layout_comb: tuple[cutlass_library.LayoutType]
|
||||
:param math_operation: math operation to consider or None if any can be considered
|
||||
:type math_operation: cutlass.MathOperation
|
||||
:type math_operation: cutlass_cppgen.MathOperation
|
||||
|
||||
:return: set of operation classes that support the provided data type and layout combination
|
||||
:rtype: set
|
||||
@@ -484,7 +484,7 @@ class ArchOptions:
|
||||
:param layout_b: layout of operand B
|
||||
:type layout_b: cutlass_library.LayoutType
|
||||
:param math_operation: math operation to consider
|
||||
:type math_operation: cutlass.MathOperation
|
||||
:type math_operation: cutlass_cppgen.MathOperation
|
||||
|
||||
:return: set of operation classes that support the provided data type combination
|
||||
:rtype: set
|
||||
@@ -524,7 +524,7 @@ class ArchOptions:
|
||||
:param layout_b: layout of operand B
|
||||
:type layout_b: cutlass_library.LayoutType
|
||||
:param math_operation: math operation to consider
|
||||
:type math_operation: cutlass.MathOperation
|
||||
:type math_operation: cutlass_cppgen.MathOperation
|
||||
|
||||
:return: container of kernels by alignment supported by the provided combination of parameters
|
||||
:rtype: KernelsForDataType
|
||||
|
||||
@@ -30,7 +30,7 @@
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
|
||||
from cutlass.op.gemm import Gemm
|
||||
from cutlass.op.gemm_grouped import GroupedGemm
|
||||
from cutlass.op.op import OperationBase
|
||||
from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
|
||||
from cutlass_cppgen.op.gemm import Gemm
|
||||
from cutlass_cppgen.op.gemm_grouped import GroupedGemm
|
||||
from cutlass_cppgen.op.op import OperationBase
|
||||
|
||||
@@ -47,7 +47,7 @@
|
||||
.. code-block:: python
|
||||
|
||||
# A, B, C, and D are torch/numpy/cupy tensor objects
|
||||
plan = cutlass.op.Conv(A, B, C, D)
|
||||
plan = cutlass_cppgen.op.Conv(A, B, C, D)
|
||||
plan.run(stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
||||
|
||||
One can also use the interface by specifying data types of operands at construction
|
||||
@@ -57,11 +57,11 @@
|
||||
.. code-block:: python
|
||||
|
||||
# The following is shorthand for:
|
||||
# cutlass.op.Conv2d(kind="fprop",
|
||||
# cutlass_cppgen.op.Conv2d(kind="fprop",
|
||||
# element_A=torch.float32, element_B=torch.float32,
|
||||
# element_C=torch.float32, element_D=torch.float32,
|
||||
# element_accumulator=torch.float32)
|
||||
plan = cutlass.op.Conv2d(kind="fprop", element=torch.float32)
|
||||
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=torch.float32)
|
||||
|
||||
A0 = torch.rand((128, 256), dtype=torch.float32, device='cuda')
|
||||
B0 = torch.rand((256, 64), dtype=torch.float32, device='cuda')
|
||||
@@ -81,7 +81,7 @@
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass.op.Conv2d(kind="fprop", element=np.float32)
|
||||
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
|
||||
|
||||
# Do other work...
|
||||
|
||||
@@ -96,15 +96,15 @@
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass.op.Conv2d(kind="fprop", element=np.float32)
|
||||
plan.activation = cutlass.epilogue.relu
|
||||
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
|
||||
plan.activation = cutlass_cppgen.epilogue.relu
|
||||
|
||||
Operations can also be run asynchronously:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass.op.Conv2d(kind="fprop", element=np.float32)
|
||||
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
|
||||
args = plan.run()
|
||||
|
||||
# Do other work...
|
||||
@@ -114,7 +114,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
from cutlass.utils.lazy_import import lazy_import
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
cudart = lazy_import("cuda.cudart")
|
||||
from cutlass_library import (
|
||||
@@ -127,15 +127,15 @@ from cutlass_library import (
|
||||
StrideSupport,
|
||||
)
|
||||
|
||||
import cutlass
|
||||
from cutlass import epilogue
|
||||
from cutlass.backend import compiler
|
||||
from cutlass.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
|
||||
from cutlass.backend.reduction_operation import ReductionOperation, ReductionArguments
|
||||
from cutlass.backend.library import TensorDescription, TileDescription
|
||||
from cutlass.op.op import OperationBase
|
||||
from cutlass.shape import Conv2DProblemSize, MatrixCoord
|
||||
from cutlass.utils import check, datatypes
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen import epilogue
|
||||
from cutlass_cppgen.backend import compiler
|
||||
from cutlass_cppgen.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
|
||||
from cutlass_cppgen.backend.reduction_operation import ReductionOperation, ReductionArguments
|
||||
from cutlass_cppgen.backend.library import TensorDescription, TileDescription
|
||||
from cutlass_cppgen.op.op import OperationBase
|
||||
from cutlass_cppgen.shape import Conv2DProblemSize, MatrixCoord
|
||||
from cutlass_cppgen.utils import check, datatypes
|
||||
|
||||
|
||||
class Conv2d(OperationBase):
|
||||
@@ -155,11 +155,11 @@ class Conv2d(OperationBase):
|
||||
# Use F32 for A, B, C, D, and accumulation in fprop
|
||||
|
||||
# Use the generic ``element`` parameter to concisely set all data types for operands to the same values.
|
||||
Conv2d(kind="fprop", element=cutlass.DataType.f32)
|
||||
Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f32)
|
||||
|
||||
# Explicitly specify the data types to use for A, B, C, and D.
|
||||
Conv2d(kind="fprop", element_A=cutlass.DataType.f32, element_B=cutlass.DataType.f32,
|
||||
element_C=cutlass.DataType.f32, element_D=cutlass.DataType.f32)
|
||||
Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32,
|
||||
element_C=cutlass_cppgen.DataType.f32, element_D=cutlass_cppgen.DataType.f32)
|
||||
|
||||
# Set the data types and elements from existing tensors. Note that one can use different tensors when
|
||||
# executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
|
||||
@@ -169,8 +169,8 @@ class Conv2d(OperationBase):
|
||||
|
||||
# Explicitly specify the data type for only some of A, B, C, and D. Unspecified data types will inherit
|
||||
# those passed in via the generic ``element``
|
||||
Conv2d(kind="fprop", element_A=cutlass.DataType.f32, element_accumulator=cutlass.DataType.f32,
|
||||
element=cutlass.DataType.f32)
|
||||
Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32,
|
||||
element=cutlass_cppgen.DataType.f32)
|
||||
|
||||
The order of precedence for the setting of the data type for a given operand/output is as follows:
|
||||
1) If the tensor type is specified (e.g., ``A``), use the data type inferred from this tensor
|
||||
@@ -186,17 +186,17 @@ class Conv2d(OperationBase):
|
||||
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
|
||||
:param beta: scalar parameter beta from GEMM operation that scales operand C
|
||||
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
|
||||
:type element: cutlass.DataType
|
||||
:type element: cutlass_cppgen.DataType
|
||||
:param element_A: data type to be used for operand A
|
||||
:type element_A: cutlass.DataType
|
||||
:type element_A: cutlass_cppgen.DataType
|
||||
:param element_B: data type to be used for operand B
|
||||
:type element_B: cutlass.DataType
|
||||
:type element_B: cutlass_cppgen.DataType
|
||||
:param element_C: data type to be used for operand C
|
||||
:type element_C: cutlass.DataType
|
||||
:type element_C: cutlass_cppgen.DataType
|
||||
:param element_D: data type to be used for operand D
|
||||
:type element_D: cutlass.DataType
|
||||
:type element_D: cutlass_cppgen.DataType
|
||||
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
|
||||
:type element_accumulator: cutlass.DataType
|
||||
:type element_accumulator: cutlass_cppgen.DataType
|
||||
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
|
||||
:type cc: int
|
||||
:param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
|
||||
@@ -215,7 +215,7 @@ class Conv2d(OperationBase):
|
||||
if self.current_cc == 90:
|
||||
# The Conv2d kernel on Hopper (SM90) is currently unsupported
|
||||
# Revert to use SM80-tagged kernels
|
||||
cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
|
||||
cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
|
||||
self.specified_kernel_cc = 80
|
||||
self._reset_options(80)
|
||||
|
||||
@@ -250,7 +250,7 @@ class Conv2d(OperationBase):
|
||||
assert elt_to_set is not None
|
||||
|
||||
# Currently we only support layout TensorNHWC
|
||||
lay_to_set = cutlass.LayoutType.TensorNHWC
|
||||
lay_to_set = cutlass_cppgen.LayoutType.TensorNHWC
|
||||
elements.append(datatypes.library_type(elt_to_set))
|
||||
layouts.append(lay_to_set)
|
||||
|
||||
@@ -301,10 +301,10 @@ class Conv2d(OperationBase):
|
||||
self._layout_a, self._layout_b, self._math_operation
|
||||
)
|
||||
|
||||
if cutlass.OpcodeClass.TensorOp in self.possible_op_classes:
|
||||
self.opclass = cutlass.OpcodeClass.TensorOp
|
||||
elif cutlass.OpcodeClass.Simt in self.possible_op_classes:
|
||||
self.opclass = cutlass.OpcodeClass.Simt
|
||||
if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes:
|
||||
self.opclass = cutlass_cppgen.OpcodeClass.TensorOp
|
||||
elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes:
|
||||
self.opclass = cutlass_cppgen.OpcodeClass.Simt
|
||||
else:
|
||||
if self._math_operation is not None:
|
||||
math_op_str = f' and math operation {self._math_operation}'
|
||||
@@ -342,7 +342,7 @@ class Conv2d(OperationBase):
|
||||
Set the tile description
|
||||
|
||||
:param td: tile description
|
||||
:type td: cutlass.backend.TileDescription, or a dict with keys
|
||||
:type td: cutlass_cppgen.backend.TileDescription, or a dict with keys
|
||||
{
|
||||
"threadblock_shape": [int, int, int],
|
||||
"warp_count": [int, int, int],
|
||||
@@ -359,7 +359,7 @@ class Conv2d(OperationBase):
|
||||
self._tile_description = datatypes.td_from_profiler_op(op)
|
||||
if "cluster_shape" in td.keys():
|
||||
if td["cluster_shape"] != [1, 1, 1]:
|
||||
cutlass.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.")
|
||||
cutlass_cppgen.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.")
|
||||
td["cluster_shape"] = [1, 1, 1]
|
||||
td = self._tile_description.clone_and_update(td)
|
||||
|
||||
@@ -381,7 +381,7 @@ class Conv2d(OperationBase):
|
||||
- Is the kernel schedule being used supported on the architecture in question?
|
||||
|
||||
:param td: tile description to validate
|
||||
:type td: cutlass.backend.TileDescription
|
||||
:type td: cutlass_cppgen.backend.TileDescription
|
||||
:return: tuple in which the first element is a bool indicating that the tile description is valid
|
||||
and the second element is a string providing an optional error message.
|
||||
:rtype: tuple
|
||||
@@ -445,9 +445,9 @@ class Conv2d(OperationBase):
|
||||
"""
|
||||
if self.conv_kind == ConvKind.Dgrad:
|
||||
if stride[0] != 1 or stride[1] != 1:
|
||||
return getattr(cutlass.swizzle, f"StridedDgradIdentitySwizzle{self._swizzling_stride}")
|
||||
return getattr(cutlass_cppgen.swizzle, f"StridedDgradIdentitySwizzle{self._swizzling_stride}")
|
||||
|
||||
return getattr(cutlass.swizzle, f"IdentitySwizzle{self._swizzling_stride}")
|
||||
return getattr(cutlass_cppgen.swizzle, f"IdentitySwizzle{self._swizzling_stride}")
|
||||
|
||||
#
|
||||
# Iterator Algorithm Related
|
||||
@@ -546,14 +546,14 @@ class Conv2d(OperationBase):
|
||||
self, tile_description: TileDescription = None,
|
||||
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
|
||||
iterator_algorithm: IteratorAlgorithm = None,
|
||||
stride_support = None, swizzling_functor: cutlass.swizzle = None,
|
||||
epilogue_functor=None) -> cutlass.backend.Conv2dOperation:
|
||||
stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None,
|
||||
epilogue_functor=None) -> cutlass_cppgen.backend.Conv2dOperation:
|
||||
"""
|
||||
Constructs a ``cutlass.backend.Conv2dOperation`` based on the input parameters and current
|
||||
Constructs a ``cutlass_cppgen.backend.Conv2dOperation`` based on the input parameters and current
|
||||
kernel specification of the ``Conv2d`` object.
|
||||
|
||||
:param tile_description: tile description specifying shapes and operand types to use in the kernel
|
||||
:type tile_description: cutlass.backend.TileDescription
|
||||
:type tile_description: cutlass_cppgen.backend.TileDescription
|
||||
:param alignment_A: alignment of operand A
|
||||
:type alignment_A: int
|
||||
:param alignment_B: alignment of operand B
|
||||
@@ -565,11 +565,11 @@ class Conv2d(OperationBase):
|
||||
:param stride_support: the stride support of dgrad
|
||||
:type stride_support: cutlass_library.library.StrideSupport
|
||||
:param swizzling_functor: the swizzling functor
|
||||
:type swizzling_functor: cutlass.swizzle
|
||||
:type swizzling_functor: cutlass_cppgen.swizzle
|
||||
:param epilogue_functor: the epilogue functor
|
||||
|
||||
:return: operation that was constructed
|
||||
:rtype: cutlass.backend.Conv2dOperation
|
||||
:rtype: cutlass_cppgen.backend.Conv2dOperation
|
||||
"""
|
||||
# Get alignment
|
||||
alignment_A = check.alignment_or_default(alignment_A, self.alignment_pref_A)
|
||||
@@ -637,8 +637,8 @@ class Conv2d(OperationBase):
|
||||
def compile(self, tile_description: TileDescription = None,
|
||||
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
|
||||
iterator_algorithm: IteratorAlgorithm = None,
|
||||
stride_support = None, swizzling_functor: cutlass.swizzle = None,
|
||||
epilogue_functor = None, print_module: bool = False) -> cutlass.backend.Conv2dOperation:
|
||||
stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None,
|
||||
epilogue_functor = None, print_module: bool = False) -> cutlass_cppgen.backend.Conv2dOperation:
|
||||
"""
|
||||
Emits and compiles the kernel currently specified. If ``tile_description`` and any
|
||||
of the ``alignment`` parameters are set, the kernel will be chosen using this
|
||||
@@ -646,7 +646,7 @@ class Conv2d(OperationBase):
|
||||
will be used.
|
||||
|
||||
::param tile_description: tile description specifying shapes and operand types to use in the kernel
|
||||
:type tile_description: cutlass.backend.TileDescription
|
||||
:type tile_description: cutlass_cppgen.backend.TileDescription
|
||||
:param alignment_A: alignment of operand A
|
||||
:type alignment_A: int
|
||||
:param alignment_B: alignment of operand B
|
||||
@@ -658,11 +658,11 @@ class Conv2d(OperationBase):
|
||||
:param stride_support: the stride support of dgrad
|
||||
:type stride_support: cutlass_library.library.StrideSupport
|
||||
:param swizzling_functor: the swizzling functor
|
||||
:type swizzling_functor: cutlass.swizzle
|
||||
:type swizzling_functor: cutlass_cppgen.swizzle
|
||||
:param epilogue_functor: the epilogue functor
|
||||
|
||||
:return: operation that was compiled
|
||||
:rtype: cutlass.backend.Conv2dOperation
|
||||
:rtype: cutlass_cppgen.backend.Conv2dOperation
|
||||
"""
|
||||
|
||||
self.operation = self.construct(
|
||||
@@ -770,7 +770,7 @@ class Conv2d(OperationBase):
|
||||
:type stream: :class:`cuda.cuda.CUstream`
|
||||
|
||||
:return: arguments passed in to the kernel
|
||||
:rtype: cutlass.backend.Conv2dArguments
|
||||
:rtype: cutlass_cppgen.backend.Conv2dArguments
|
||||
"""
|
||||
if not stream:
|
||||
stream = cuda.CUstream(0)
|
||||
|
||||
@@ -47,7 +47,7 @@
|
||||
.. code-block:: python
|
||||
|
||||
# A, B, C, and D are torch/numpy/cupy tensor objects
|
||||
plan = cutlass.op.Gemm(A, B, C, D)
|
||||
plan = cutlass_cppgen.op.Gemm(A, B, C, D)
|
||||
plan.run()
|
||||
|
||||
|
||||
@@ -58,11 +58,11 @@
|
||||
.. code-block:: python
|
||||
|
||||
# The following is shorthand for:
|
||||
# cutlass.op.Gemm(element_A=torch.float32, element_B=torch.float32,
|
||||
# cutlass_cppgen.op.Gemm(element_A=torch.float32, element_B=torch.float32,
|
||||
# element_C=torch.float32, element_D=torch.float32,
|
||||
# element_accumulator=torch.float32,
|
||||
# layout=cutlass.LayoutType.RowMajor)
|
||||
plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)
|
||||
# layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
|
||||
A0 = torch.rand((128, 256), device='cuda')
|
||||
B0 = torch.rand((256, 64), device='cuda')
|
||||
@@ -82,7 +82,7 @@
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)
|
||||
plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
plan.compile()
|
||||
|
||||
# Do other work...
|
||||
@@ -98,15 +98,15 @@
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)
|
||||
plan.activation = cutlass.epilogue.relu
|
||||
plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
plan.activation = cutlass_cppgen.epilogue.relu
|
||||
|
||||
Operations can also be run asynchronously:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)
|
||||
plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
args = plan.run()
|
||||
|
||||
# Do other work...
|
||||
@@ -117,7 +117,7 @@ from __future__ import annotations
|
||||
from typing import Optional
|
||||
from math import prod
|
||||
|
||||
from cutlass.utils.lazy_import import lazy_import
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
from cutlass_library import (
|
||||
DataType,
|
||||
@@ -125,15 +125,15 @@ from cutlass_library import (
|
||||
GemmUniversalMode,
|
||||
)
|
||||
|
||||
import cutlass
|
||||
from cutlass import epilogue, swizzle
|
||||
from cutlass.backend import compiler
|
||||
from cutlass.backend.evt import EpilogueFunctorVisitor
|
||||
from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal
|
||||
from cutlass.backend.library import TensorDescription, TileDescription
|
||||
from cutlass.op.op import OperationBase
|
||||
from cutlass.shape import GemmCoord
|
||||
from cutlass.utils import check, datatypes
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen import epilogue, swizzle
|
||||
from cutlass_cppgen.backend import compiler
|
||||
from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
|
||||
from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal
|
||||
from cutlass_cppgen.backend.library import TensorDescription, TileDescription
|
||||
from cutlass_cppgen.op.op import OperationBase
|
||||
from cutlass_cppgen.shape import GemmCoord
|
||||
from cutlass_cppgen.utils import check, datatypes
|
||||
|
||||
|
||||
class Gemm(OperationBase):
|
||||
@@ -154,11 +154,11 @@ class Gemm(OperationBase):
|
||||
|
||||
# Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts
|
||||
# for operands to the same values.
|
||||
Gemm(element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)
|
||||
Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
|
||||
# Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``.
|
||||
Gemm(element_A=cutlass.DataType.f32, element_B=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
|
||||
element_D=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)
|
||||
Gemm(element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32, element_C=cutlass_cppgen.DataType.f32,
|
||||
element_D=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
|
||||
# Set the data types and elements from existing tensors. Note that one can use different tensors when
|
||||
# executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
|
||||
@@ -168,13 +168,13 @@ class Gemm(OperationBase):
|
||||
|
||||
# Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (layout of D is
|
||||
# the same as that for D, at present)
|
||||
Gemm(element=cutlass.DataType.f32, layout_A=cutlass.LayoutType.RowMajor,
|
||||
layout_B=cutlass.LayoutType.RowMajor, layout_C=cutlass.LayoutType.RowMajor)
|
||||
Gemm(element=cutlass_cppgen.DataType.f32, layout_A=cutlass_cppgen.LayoutType.RowMajor,
|
||||
layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor)
|
||||
|
||||
# Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types
|
||||
# and layouts will inherit those passed in via the generic ``element`` and ``layout``
|
||||
Gemm(element_A=cutlass.DataType.f32, layout_B=cutlass.LayoutType.RowMajor,
|
||||
element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)
|
||||
Gemm(element_A=cutlass_cppgen.DataType.f32, layout_B=cutlass_cppgen.LayoutType.RowMajor,
|
||||
element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
|
||||
The order of precedence for the setting of the data type and layout for a given operand/output is as follows:
|
||||
1) If the tensor type is specified (e.g., ``A``), use the data type and layout inferred from this tensor
|
||||
@@ -192,27 +192,27 @@ class Gemm(OperationBase):
|
||||
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
|
||||
:param beta: scalar parameter beta from GEMM operation that scales operand C
|
||||
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
|
||||
:type element_accumulator: cutlass.DataType
|
||||
:type element_accumulator: cutlass_cppgen.DataType
|
||||
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
|
||||
:type element: cutlass.DataType
|
||||
:type element: cutlass_cppgen.DataType
|
||||
:param layout: generic layout type to be used for operands A, B, C, and D
|
||||
:type layout: cutlass.LayoutType
|
||||
:type layout: cutlass_cppgen.LayoutType
|
||||
:param element_A: data type to be used for operand A
|
||||
:type element_A: cutlass.DataType
|
||||
:type element_A: cutlass_cppgen.DataType
|
||||
:param element_B: data type to be used for operand B
|
||||
:type element_B: cutlass.DataType
|
||||
:type element_B: cutlass_cppgen.DataType
|
||||
:param element_C: data type to be used for operand C
|
||||
:type element_C: cutlass.DataType
|
||||
:type element_C: cutlass_cppgen.DataType
|
||||
:param element_D: data type to be used for operand D
|
||||
:type element_D: cutlass.DataType
|
||||
:type element_D: cutlass_cppgen.DataType
|
||||
:param layout_A: layout of operand A
|
||||
:type layout_A: cutlass.LayoutType
|
||||
:type layout_A: cutlass_cppgen.LayoutType
|
||||
:param layout_B: layout of operand B
|
||||
:type layout_B: cutlass.LayoutType
|
||||
:type layout_B: cutlass_cppgen.LayoutType
|
||||
:param layout_C: layout of operand C
|
||||
:type layout_C: cutlass.LayoutType
|
||||
:type layout_C: cutlass_cppgen.LayoutType
|
||||
:param layout_D: layout of operand D
|
||||
:type layout_D: cutlass.LayoutType
|
||||
:type layout_D: cutlass_cppgen.LayoutType
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -278,7 +278,7 @@ class Gemm(OperationBase):
|
||||
|
||||
self._reset_operations()
|
||||
|
||||
self._swizzling_functor = cutlass.swizzle.IdentitySwizzle1
|
||||
self._swizzling_functor = cutlass_cppgen.swizzle.IdentitySwizzle1
|
||||
|
||||
def _reset_operations(self, reset_epilogue: bool = True):
|
||||
# Set the default op class
|
||||
@@ -289,10 +289,10 @@ class Gemm(OperationBase):
|
||||
self._element_a, self._element_b, self._element_accumulator,
|
||||
self._layout_a, self._layout_b, self._math_operation)
|
||||
|
||||
if cutlass.OpcodeClass.TensorOp in self.possible_op_classes:
|
||||
self.opclass = cutlass.OpcodeClass.TensorOp
|
||||
elif cutlass.OpcodeClass.Simt in self.possible_op_classes:
|
||||
self.opclass = cutlass.OpcodeClass.Simt
|
||||
if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes:
|
||||
self.opclass = cutlass_cppgen.OpcodeClass.TensorOp
|
||||
elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes:
|
||||
self.opclass = cutlass_cppgen.OpcodeClass.Simt
|
||||
else:
|
||||
if self._math_operation is not None:
|
||||
math_op_str = f' and math operation {self._math_operation}'
|
||||
@@ -303,7 +303,7 @@ class Gemm(OperationBase):
|
||||
f'combination {datatype_comb}x{layout_comb}{math_op_str}')
|
||||
|
||||
if reset_epilogue:
|
||||
self._reset_epilogue_functor_activation(cutlass.epilogue.identity)
|
||||
self._reset_epilogue_functor_activation(cutlass_cppgen.epilogue.identity)
|
||||
|
||||
@property
|
||||
def swizzling_functor(self):
|
||||
@@ -319,8 +319,8 @@ class Gemm(OperationBase):
|
||||
"""
|
||||
Sets the swizzling functor to the type specified by `swizzling_functor`
|
||||
"""
|
||||
if swizzling_functor == cutlass.swizzle.ThreadblockSwizzleStreamK:
|
||||
if self.op_class == cutlass.OpcodeClass.Simt:
|
||||
if swizzling_functor == cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK:
|
||||
if self.op_class == cutlass_cppgen.OpcodeClass.Simt:
|
||||
raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp')
|
||||
|
||||
if self.current_cc == 90:
|
||||
@@ -345,7 +345,7 @@ class Gemm(OperationBase):
|
||||
Set the tile description
|
||||
|
||||
:param td: tile description
|
||||
:type td: cutlass.backend.TileDescription, or a dict with keys
|
||||
:type td: cutlass_cppgen.backend.TileDescription, or a dict with keys
|
||||
{
|
||||
"threadblock_shape": [int, int, int],
|
||||
"warp_count": [int, int, int],
|
||||
@@ -380,7 +380,7 @@ class Gemm(OperationBase):
|
||||
- Is the kernel schedule being used supported on the architecture in question?
|
||||
|
||||
:param td: tile description to validate
|
||||
:type td: cutlass.backend.TileDescription
|
||||
:type td: cutlass_cppgen.backend.TileDescription
|
||||
:return: tuple in which the first element is a bool indicating that the tile description is valid
|
||||
and the second element is a string providing an optional error message.
|
||||
:rtype: tuple
|
||||
@@ -412,11 +412,11 @@ class Gemm(OperationBase):
|
||||
self, tile_description: TileDescription = None,
|
||||
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None) -> GemmOperationUniversal:
|
||||
"""
|
||||
Constructs a ``cutlass.backend.GemmUniversalOperation`` based on the input parameters and current
|
||||
Constructs a ``cutlass_cppgen.backend.GemmUniversalOperation`` based on the input parameters and current
|
||||
kernel specification of the ``Gemm`` object.
|
||||
|
||||
:param tile_description: tile description specifying shapes and operand types to use in the kernel
|
||||
:type tile_description: cutlass.backend.TileDescription
|
||||
:type tile_description: cutlass_cppgen.backend.TileDescription
|
||||
:param alignment_A: alignment of operand A
|
||||
:type alignment_A: int
|
||||
:param alignment_B: alignment of operand B
|
||||
@@ -425,7 +425,7 @@ class Gemm(OperationBase):
|
||||
:type alignment_C: int
|
||||
|
||||
:return: operation that was constructed
|
||||
:rtype: cutlass.backend.GemmOperationUniversal
|
||||
:rtype: cutlass_cppgen.backend.GemmOperationUniversal
|
||||
"""
|
||||
alignment_pref_A = min(128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A")))
|
||||
alignment_pref_B = min(128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B")))
|
||||
@@ -471,7 +471,7 @@ class Gemm(OperationBase):
|
||||
|
||||
def compile(self, tile_description: TileDescription = None,
|
||||
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
|
||||
print_module: bool = False) -> cutlass.backend.GemmOperationUniversal:
|
||||
print_module: bool = False) -> cutlass_cppgen.backend.GemmOperationUniversal:
|
||||
"""
|
||||
Emits and compiles the kernel currently specified. If ``tile_description`` and any
|
||||
of the ``alignment`` parameters are set, the kernel will be chosen using this
|
||||
@@ -479,7 +479,7 @@ class Gemm(OperationBase):
|
||||
will be used.
|
||||
|
||||
:param tile_description: tile description specifying shapes and operand types to use in the kernel
|
||||
:type tile_description: cutlass.backend.TileDescription
|
||||
:type tile_description: cutlass_cppgen.backend.TileDescription
|
||||
:param alignment_A: alignment of operand A
|
||||
:type alignment_A: int
|
||||
:param alignment_B: alignment of operand B
|
||||
@@ -490,7 +490,7 @@ class Gemm(OperationBase):
|
||||
:type print_module: bool
|
||||
|
||||
:return: operation that was compiled
|
||||
:rtype: cutlass.backend.GemmOperationUniversal
|
||||
:rtype: cutlass_cppgen.backend.GemmOperationUniversal
|
||||
"""
|
||||
self.operation = self.construct(tile_description, alignment_A, alignment_B, alignment_C)
|
||||
|
||||
@@ -566,7 +566,7 @@ class Gemm(OperationBase):
|
||||
:param D: tensor D
|
||||
:type D: numpy/cupy/torch array/tensor object
|
||||
|
||||
:return: tuple containing the problem size (cutlass.shape.GemmCoord), the GEMM mode (cutlass.GemmUniversalMode), and the batch count (int)
|
||||
:return: tuple containing the problem size (cutlass_cppgen.shape.GemmCoord), the GEMM mode (cutlass_cppgen.GemmUniversalMode), and the batch count (int)
|
||||
:rtype: tuple
|
||||
"""
|
||||
M, K = A.shape[-2:]
|
||||
@@ -582,9 +582,9 @@ class Gemm(OperationBase):
|
||||
# and C are row major. A similar operation can be performed if only B has a nonzero
|
||||
# batch dimension
|
||||
if batch_count > 1:
|
||||
A_row = self._layout_a == cutlass.LayoutType.RowMajor
|
||||
B_row = self._layout_b == cutlass.LayoutType.RowMajor
|
||||
C_row = self._layout_c == cutlass.LayoutType.RowMajor
|
||||
A_row = self._layout_a == cutlass_cppgen.LayoutType.RowMajor
|
||||
B_row = self._layout_b == cutlass_cppgen.LayoutType.RowMajor
|
||||
C_row = self._layout_c == cutlass_cppgen.LayoutType.RowMajor
|
||||
|
||||
# Consider a Tensor to be batched if its rank is > 2 and
|
||||
# the product of the modes beyond rank 2 equals our pre-determined batch size.
|
||||
@@ -652,7 +652,7 @@ class Gemm(OperationBase):
|
||||
:type stream: :class:`cuda.cuda.CUstream`
|
||||
|
||||
:return: arguments passed in to the kernel
|
||||
:rtype: cutlass.backend.GemmArguments
|
||||
:rtype: cutlass_cppgen.backend.GemmArguments
|
||||
"""
|
||||
if not stream:
|
||||
stream = cuda.CUstream(0)
|
||||
|
||||
@@ -47,27 +47,27 @@
|
||||
.. code-block:: python
|
||||
|
||||
# As, Bs, Cs, and Ds are torch/numpy/cupy tensor objects
|
||||
plan = cutlass.op.GroupedGemm(element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor)
|
||||
plan = cutlass_cppgen.op.GroupedGemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor)
|
||||
plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1])
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
from cutlass_library import DataTypeSize
|
||||
|
||||
from cutlass.utils.lazy_import import lazy_import
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
from cutlass.backend.gemm_operation import (
|
||||
from cutlass_cppgen.backend.gemm_operation import (
|
||||
GemmGroupedArguments,
|
||||
GemmOperationGrouped,
|
||||
)
|
||||
from cutlass.backend.library import (
|
||||
from cutlass_cppgen.backend.library import (
|
||||
SchedulerMode,
|
||||
TensorDescription,
|
||||
TileDescription,
|
||||
)
|
||||
from cutlass.op.gemm import Gemm
|
||||
from cutlass.shape import GemmCoord
|
||||
from cutlass.utils import check, datatypes
|
||||
from cutlass_cppgen.op.gemm import Gemm
|
||||
from cutlass_cppgen.shape import GemmCoord
|
||||
from cutlass_cppgen.utils import check, datatypes
|
||||
|
||||
|
||||
class GroupedGemm(Gemm):
|
||||
@@ -90,27 +90,27 @@ class GroupedGemm(Gemm):
|
||||
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
|
||||
:param beta: scalar parameter beta from GEMM operation that scales operand C
|
||||
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
|
||||
:type element_accumulator: cutlass.DataType
|
||||
:type element_accumulator: cutlass_cppgen.DataType
|
||||
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
|
||||
:type element: cutlass.DataType
|
||||
:type element: cutlass_cppgen.DataType
|
||||
:param layout: generic layout type to be used for operands A, B, C, and D
|
||||
:type layout: cutlass.LayoutType
|
||||
:type layout: cutlass_cppgen.LayoutType
|
||||
:param element_A: data type to be used for operand A
|
||||
:type element_A: cutlass.DataType
|
||||
:type element_A: cutlass_cppgen.DataType
|
||||
:param element_B: data type to be used for operand B
|
||||
:type element_B: cutlass.DataType
|
||||
:type element_B: cutlass_cppgen.DataType
|
||||
:param element_C: data type to be used for operand C
|
||||
:type element_C: cutlass.DataType
|
||||
:type element_C: cutlass_cppgen.DataType
|
||||
:param element_D: data type to be used for operand D
|
||||
:type element_D: cutlass.DataType
|
||||
:type element_D: cutlass_cppgen.DataType
|
||||
:type layout_A: layout of operand A
|
||||
:param layout_A: cutlass.LayoutType
|
||||
:param layout_A: cutlass_cppgen.LayoutType
|
||||
:type layout_B: layout of operand B
|
||||
:param layout_B: cutlass.LayoutType
|
||||
:param layout_B: cutlass_cppgen.LayoutType
|
||||
:type layout_C: layout of operand C
|
||||
:param layout_C: cutlass.LayoutType
|
||||
:param layout_C: cutlass_cppgen.LayoutType
|
||||
:type layout_D: layout of operand D
|
||||
:param layout_D: cutlass.LayoutType
|
||||
:param layout_D: cutlass_cppgen.LayoutType
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -151,11 +151,11 @@ class GroupedGemm(Gemm):
|
||||
alignment_B: int = None,
|
||||
alignment_C: int = None) -> GemmOperationGrouped:
|
||||
"""
|
||||
Constructs a ``cutlass.backend.GemmOperationGrouped`` based on the input parameters and current
|
||||
Constructs a ``cutlass_cppgen.backend.GemmOperationGrouped`` based on the input parameters and current
|
||||
kernel specification of the ``Gemm`` object.
|
||||
|
||||
:param tile_description: tile description specifying shapes and operand types to use in the kernel
|
||||
:type tile_description: cutlass.backend.TileDescription
|
||||
:type tile_description: cutlass_cppgen.backend.TileDescription
|
||||
:param alignment_A: alignment of operand A
|
||||
:type alignment_A: int
|
||||
:param alignment_B: alignment of operand B
|
||||
@@ -164,7 +164,7 @@ class GroupedGemm(Gemm):
|
||||
:type alignment_C: int
|
||||
|
||||
:return: operation that was constructed
|
||||
:rtype: cutlass.backend.GemmOperationGrouped
|
||||
:rtype: cutlass_cppgen.backend.GemmOperationGrouped
|
||||
"""
|
||||
alignment_A = check.alignment_or_default(alignment_A, max(self.possible_operations.alignments("A")))
|
||||
alignment_B = check.alignment_or_default(alignment_B, max(self.possible_operations.alignments("B")))
|
||||
@@ -225,7 +225,7 @@ class GroupedGemm(Gemm):
|
||||
:type stream: :class:`cuda.cuda.CUstream`
|
||||
|
||||
:return: arguments passed in to the kernel
|
||||
:rtype: cutlass.backend.GemmGroupedArguments
|
||||
:rtype: cutlass_cppgen.backend.GemmGroupedArguments
|
||||
"""
|
||||
if not stream:
|
||||
stream = cuda.CUstream(0)
|
||||
|
||||
@@ -44,14 +44,14 @@ from cutlass_library import (
|
||||
SharedMemPerCC
|
||||
)
|
||||
|
||||
import cutlass
|
||||
from cutlass import get_option_registry
|
||||
from cutlass.backend.evt import EpilogueFunctorVisitor
|
||||
from cutlass.backend.utils.device import device_cc
|
||||
from cutlass.epilogue import get_activations, get_activation_epilogue, identity
|
||||
from cutlass.library_defaults import KernelsForDataType, _generator_ccs
|
||||
from cutlass.swizzle import get_swizzling_functors
|
||||
from cutlass.utils import datatypes, check
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen import get_option_registry
|
||||
from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
|
||||
from cutlass_cppgen.backend.utils.device import device_cc
|
||||
from cutlass_cppgen.epilogue import get_activations, get_activation_epilogue, identity
|
||||
from cutlass_cppgen.library_defaults import KernelsForDataType, _generator_ccs
|
||||
from cutlass_cppgen.swizzle import get_swizzling_functors
|
||||
from cutlass_cppgen.utils import datatypes, check
|
||||
|
||||
|
||||
class OperationBase:
|
||||
@@ -205,19 +205,19 @@ class OperationBase:
|
||||
return tensor
|
||||
|
||||
@property
|
||||
def opclass(self) -> cutlass.OpcodeClass:
|
||||
def opclass(self) -> cutlass_cppgen.OpcodeClass:
|
||||
"""
|
||||
Returns the opcode class currently in use
|
||||
|
||||
:return: opcode class currently in use
|
||||
:rtype: cutlass.OpcodeClass
|
||||
:rtype: cutlass_cppgen.OpcodeClass
|
||||
"""
|
||||
return self.op_class
|
||||
|
||||
@opclass.setter
|
||||
def opclass(self, oc: cutlass.OpcodeClass):
|
||||
def opclass(self, oc: cutlass_cppgen.OpcodeClass):
|
||||
if isinstance(oc, str):
|
||||
oc = datatypes.getattr_enum(cutlass.OpcodeClass, oc)
|
||||
oc = datatypes.getattr_enum(cutlass_cppgen.OpcodeClass, oc)
|
||||
if oc in self.possible_op_classes:
|
||||
self.op_class = oc
|
||||
else:
|
||||
@@ -236,25 +236,25 @@ class OperationBase:
|
||||
self.epilogue_functor = self._reset_epilogue_functor_alignment(self._elements_per_access(), self.epilogue_functor)
|
||||
|
||||
@property
|
||||
def math_operation(self) -> cutlass.MathOperation:
|
||||
def math_operation(self) -> cutlass_cppgen.MathOperation:
|
||||
"""
|
||||
Returns the math operation currently in use
|
||||
|
||||
:return: math operation currently in use
|
||||
:rtype: cutlass.MathOperation
|
||||
:rtype: cutlass_cppgen.MathOperation
|
||||
"""
|
||||
return self._math_operation
|
||||
|
||||
@math_operation.setter
|
||||
def math_operation(self, mo: cutlass.MathOperation):
|
||||
def math_operation(self, mo: cutlass_cppgen.MathOperation):
|
||||
if isinstance(mo, str):
|
||||
mo = datatypes.getattr_enum(cutlass.MathOperation, mo)
|
||||
mo = datatypes.getattr_enum(cutlass_cppgen.MathOperation, mo)
|
||||
|
||||
if not self.specified_kernel_cc:
|
||||
if self.current_cc == 90:
|
||||
# CUTLASS 3.0 kernels do not use different math operations. If one is specified, we
|
||||
# revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
|
||||
cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
|
||||
cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
|
||||
self._reset_options(80)
|
||||
self._reset_operations(reset_epilogue=False)
|
||||
elif self.current_cc == 90:
|
||||
@@ -266,7 +266,7 @@ class OperationBase:
|
||||
self._reset_operations()
|
||||
|
||||
def _elements_per_access(self):
|
||||
if self.op_class == cutlass.OpcodeClass.Simt:
|
||||
if self.op_class == cutlass_cppgen.OpcodeClass.Simt:
|
||||
return 1
|
||||
elif self._element_c != DataType.void:
|
||||
return 128 // DataTypeSize[self._element_c]
|
||||
@@ -286,7 +286,7 @@ class OperationBase:
|
||||
if self.current_cc == 90 and activation != identity:
|
||||
# CUTLASS 3.0 kernels in Python currently only support identity activation. If one requests a non-identity activation,
|
||||
# revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
|
||||
cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
|
||||
cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
|
||||
if self._element_c != self._element_d:
|
||||
raise Exception("CUTLASS 2.x kernels require element C to be the same as element D")
|
||||
self._reset_options(80)
|
||||
@@ -361,7 +361,7 @@ class OperationBase:
|
||||
"""
|
||||
if isinstance(act, tuple):
|
||||
if isinstance(act[0], str):
|
||||
act_fn = getattr(cutlass.backend.epilogue, act[0])
|
||||
act_fn = getattr(cutlass_cppgen.backend.epilogue, act[0])
|
||||
else:
|
||||
act_fn = act[0]
|
||||
self._reset_epilogue_functor_activation(act_fn)
|
||||
@@ -369,7 +369,7 @@ class OperationBase:
|
||||
self._activation = act[0]
|
||||
else:
|
||||
if isinstance(act, str):
|
||||
act = getattr(cutlass.backend.epilogue, act)
|
||||
act = getattr(cutlass_cppgen.backend.epilogue, act)
|
||||
self._reset_epilogue_functor_activation(act)
|
||||
self._activation = act
|
||||
|
||||
@@ -401,8 +401,8 @@ class OperationBase:
|
||||
td = datatypes.td_from_profiler_op(operation)
|
||||
# Filter invalid epilogue schedules
|
||||
if td.epilogue_schedule not in [
|
||||
cutlass.EpilogueScheduleType.TmaWarpSpecialized,
|
||||
cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative]:
|
||||
cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized,
|
||||
cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative]:
|
||||
continue
|
||||
epilogue_smem_bytes = self.epilogue_functor.get_smem_size(td)
|
||||
|
||||
@@ -427,4 +427,4 @@ class OperationBase:
|
||||
Steps that must be taken before caling `plan.run()`
|
||||
"""
|
||||
# Initialize the memory pool if, if not already done
|
||||
cutlass.get_memory_pool()
|
||||
cutlass_cppgen.get_memory_pool()
|
||||
|
||||
@@ -39,7 +39,7 @@ from cutlass_library import (
|
||||
ConvKind,
|
||||
LayoutType
|
||||
)
|
||||
from cutlass.backend.c_types import (
|
||||
from cutlass_cppgen.backend.c_types import (
|
||||
Conv2DProblemSize_,
|
||||
GemmCoord_,
|
||||
GemmCoordBatched_
|
||||
|
||||
@@ -30,7 +30,7 @@
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass.utils.check import (
|
||||
from cutlass_cppgen.utils.check import (
|
||||
alignment_or_default,
|
||||
calculate_smem_usage,
|
||||
calculate_smem_usage_per_stage,
|
||||
|
||||
@@ -38,8 +38,8 @@ import ctypes
|
||||
|
||||
from cutlass_library import DataTypeSize, OperationKind, SharedMemPerCC
|
||||
|
||||
import cutlass
|
||||
from cutlass.backend.library import TileDescription
|
||||
import cutlass_cppgen
|
||||
from cutlass_cppgen.backend.library import TileDescription
|
||||
|
||||
|
||||
def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: OperationKind) -> int:
|
||||
@@ -82,8 +82,8 @@ def valid_stage_count(
|
||||
cc: int,
|
||||
kernel_cc: int,
|
||||
td: TileDescription,
|
||||
element_C: cutlass.DataType = None,
|
||||
element_D: cutlass.DataType = None,
|
||||
element_C: cutlass_cppgen.DataType = None,
|
||||
element_D: cutlass_cppgen.DataType = None,
|
||||
verbose: bool = True) -> tuple:
|
||||
"""
|
||||
Checks whether a device with `cc` supports the number of stages within `tile_description`, both
|
||||
@@ -96,9 +96,9 @@ def valid_stage_count(
|
||||
:param td: tile description to check
|
||||
:type td: TileDescription
|
||||
:param element_C: data type of operand C
|
||||
:type element_C: cutlass.DataType
|
||||
:type element_C: cutlass_cppgen.DataType
|
||||
:param element_D: data type of operand D
|
||||
:type element_D: cutlass.DataType
|
||||
:type element_D: cutlass_cppgen.DataType
|
||||
:param verbose: whether to log warnings
|
||||
:type verbose: bool
|
||||
|
||||
@@ -112,7 +112,7 @@ def valid_stage_count(
|
||||
# determines the stage count to use. Thus, all settings are valid in these scenarios.
|
||||
return (True, "")
|
||||
elif verbose:
|
||||
cutlass.logger.warning(
|
||||
cutlass_cppgen.logger.warning(
|
||||
"Setting an explicit stage count for SM90 kernels currently may "
|
||||
"result in compilation errors if the combination of tile shape, "
|
||||
"stage count, and shared memory requirement of the epilogue exceeds "
|
||||
@@ -188,9 +188,9 @@ def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple:
|
||||
|
||||
def valid_schedule(
|
||||
cc: int,
|
||||
kernel_schedule: cutlass.KernelScheduleType,
|
||||
epilogue_schedule: cutlass.EpilogueScheduleType,
|
||||
tile_scheduler: cutlass.TileSchedulerType) -> tuple:
|
||||
kernel_schedule: cutlass_cppgen.KernelScheduleType,
|
||||
epilogue_schedule: cutlass_cppgen.EpilogueScheduleType,
|
||||
tile_scheduler: cutlass_cppgen.TileSchedulerType) -> tuple:
|
||||
"""
|
||||
Checks that the kernel and epilogue schedules passed in are a valid combination for
|
||||
a device of compute capability ``cc``.
|
||||
@@ -198,19 +198,19 @@ def valid_schedule(
|
||||
:param cc: compute capability of device in question
|
||||
:type cc: int
|
||||
:param kernel_schedule: kernel schedule type
|
||||
:type kernel_schedule: cutlass.KernelScheduleType
|
||||
:type kernel_schedule: cutlass_cppgen.KernelScheduleType
|
||||
:param epilogue_schedule: epilogue schedule type
|
||||
:type epilogue_schedule: cutlass.EpilogueScheduleType
|
||||
:type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType
|
||||
:param tile_scheduler: tile scheduler type
|
||||
:type tile_scheduler: cutlass.TileSchedulerType
|
||||
:type tile_scheduler: cutlass_cppgen.TileSchedulerType
|
||||
|
||||
:return: tuple with the first element indicating whether the provided schedules are
|
||||
valid for the provided device and the second element being an error message
|
||||
:rtype: tuple
|
||||
"""
|
||||
kernel_auto = (kernel_schedule == cutlass.KernelScheduleType.ScheduleAuto)
|
||||
epilogue_auto = (epilogue_schedule == cutlass.EpilogueScheduleType.ScheduleAuto)
|
||||
tile_scheduler_default = (tile_scheduler == cutlass.TileSchedulerType.Default)
|
||||
kernel_auto = (kernel_schedule == cutlass_cppgen.KernelScheduleType.ScheduleAuto)
|
||||
epilogue_auto = (epilogue_schedule == cutlass_cppgen.EpilogueScheduleType.ScheduleAuto)
|
||||
tile_scheduler_default = (tile_scheduler == cutlass_cppgen.TileSchedulerType.Default)
|
||||
if cc < 90 and not (kernel_auto and epilogue_auto and tile_scheduler_default):
|
||||
return (False, "Non-default schedules are only supported on SM90 and beyond")
|
||||
|
||||
@@ -218,9 +218,9 @@ def valid_schedule(
|
||||
return (False, "Kernel and epilogue schedules must either both be auto or neither be auto")
|
||||
|
||||
if not tile_scheduler_default:
|
||||
cooperative_kernels = [cutlass.KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
cutlass.KernelScheduleType.CpAsyncWarpSpecializedCooperative]
|
||||
if (tile_scheduler == cutlass.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels):
|
||||
cooperative_kernels = [cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
cutlass_cppgen.KernelScheduleType.CpAsyncWarpSpecializedCooperative]
|
||||
if (tile_scheduler == cutlass_cppgen.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels):
|
||||
return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule")
|
||||
return (True, "")
|
||||
|
||||
|
||||
@@ -34,13 +34,13 @@
|
||||
Utility functions for converting between frontend datatypes and CUTLASS datatypes
|
||||
"""
|
||||
|
||||
import cutlass
|
||||
import cutlass_cppgen
|
||||
from cutlass_library import (
|
||||
DataTypeSize,
|
||||
MathOperation,
|
||||
MathInstruction
|
||||
)
|
||||
from cutlass.backend.library import (
|
||||
from cutlass_cppgen.backend.library import (
|
||||
TileDescription,
|
||||
)
|
||||
|
||||
@@ -62,11 +62,11 @@ def is_numpy_available():
|
||||
|
||||
numpy_available = True
|
||||
_library_to_numpy_dict = {
|
||||
cutlass.DataType.f16: np.float16,
|
||||
cutlass.DataType.f32: np.float32,
|
||||
cutlass.DataType.f64: np.float64,
|
||||
cutlass.DataType.s8: np.int8,
|
||||
cutlass.DataType.s32: np.int32,
|
||||
cutlass_cppgen.DataType.f16: np.float16,
|
||||
cutlass_cppgen.DataType.f32: np.float32,
|
||||
cutlass_cppgen.DataType.f64: np.float64,
|
||||
cutlass_cppgen.DataType.s8: np.int8,
|
||||
cutlass_cppgen.DataType.s32: np.int32,
|
||||
}
|
||||
except ImportError:
|
||||
numpy_available = False
|
||||
@@ -81,19 +81,19 @@ def is_numpy_tensor(inp) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def numpy_library_type(inp) -> cutlass.DataType:
|
||||
def numpy_library_type(inp) -> cutlass_cppgen.DataType:
|
||||
if is_numpy_available():
|
||||
import numpy as np
|
||||
if inp == np.float16:
|
||||
return cutlass.DataType.f16
|
||||
return cutlass_cppgen.DataType.f16
|
||||
elif inp == np.float32:
|
||||
return cutlass.DataType.f32
|
||||
return cutlass_cppgen.DataType.f32
|
||||
elif inp == np.float64:
|
||||
return cutlass.DataType.f64
|
||||
return cutlass_cppgen.DataType.f64
|
||||
elif inp == np.int8:
|
||||
return cutlass.DataType.s8
|
||||
return cutlass_cppgen.DataType.s8
|
||||
elif inp == np.int32:
|
||||
return cutlass.DataType.s32
|
||||
return cutlass_cppgen.DataType.s32
|
||||
return None
|
||||
|
||||
|
||||
@@ -109,11 +109,11 @@ def is_cupy_available():
|
||||
|
||||
cupy_available = True
|
||||
_library_to_cupy_dict = {
|
||||
cutlass.DataType.f16: cp.float16,
|
||||
cutlass.DataType.f32: cp.float32,
|
||||
cutlass.DataType.f64: cp.float64,
|
||||
cutlass.DataType.s8: cp.int8,
|
||||
cutlass.DataType.s32: cp.int32,
|
||||
cutlass_cppgen.DataType.f16: cp.float16,
|
||||
cutlass_cppgen.DataType.f32: cp.float32,
|
||||
cutlass_cppgen.DataType.f64: cp.float64,
|
||||
cutlass_cppgen.DataType.s8: cp.int8,
|
||||
cutlass_cppgen.DataType.s32: cp.int32,
|
||||
}
|
||||
except ImportError:
|
||||
cupy_available = False
|
||||
@@ -128,15 +128,15 @@ def is_cupy_tensor(inp) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def cupy_library_type(inp) -> cutlass.DataType:
|
||||
def cupy_library_type(inp) -> cutlass_cppgen.DataType:
|
||||
if is_cupy_available():
|
||||
import cupy as cp
|
||||
if inp == cp.float16:
|
||||
return cutlass.DataType.f16
|
||||
return cutlass_cppgen.DataType.f16
|
||||
elif inp == cp.float32:
|
||||
return cutlass.DataType.f32
|
||||
return cutlass_cppgen.DataType.f32
|
||||
elif inp == cp.float64:
|
||||
return cutlass.DataType.f64
|
||||
return cutlass_cppgen.DataType.f64
|
||||
return None
|
||||
|
||||
|
||||
@@ -152,29 +152,29 @@ def is_torch_available():
|
||||
|
||||
torch_available = True
|
||||
_torch_to_library_dict = {
|
||||
torch.half: cutlass.DataType.f16,
|
||||
torch.float16: cutlass.DataType.f16,
|
||||
torch.bfloat16: cutlass.DataType.bf16,
|
||||
torch.float: cutlass.DataType.f32,
|
||||
torch.float32: cutlass.DataType.f32,
|
||||
torch.double: cutlass.DataType.f64,
|
||||
torch.float64: cutlass.DataType.f64,
|
||||
torch.int8: cutlass.DataType.s8,
|
||||
torch.int32: cutlass.DataType.s32,
|
||||
torch.uint8: cutlass.DataType.u8,
|
||||
torch.half: cutlass_cppgen.DataType.f16,
|
||||
torch.float16: cutlass_cppgen.DataType.f16,
|
||||
torch.bfloat16: cutlass_cppgen.DataType.bf16,
|
||||
torch.float: cutlass_cppgen.DataType.f32,
|
||||
torch.float32: cutlass_cppgen.DataType.f32,
|
||||
torch.double: cutlass_cppgen.DataType.f64,
|
||||
torch.float64: cutlass_cppgen.DataType.f64,
|
||||
torch.int8: cutlass_cppgen.DataType.s8,
|
||||
torch.int32: cutlass_cppgen.DataType.s32,
|
||||
torch.uint8: cutlass_cppgen.DataType.u8,
|
||||
}
|
||||
|
||||
_library_to_torch_dict = {
|
||||
cutlass.DataType.f16: torch.half,
|
||||
cutlass.DataType.f16: torch.float16,
|
||||
cutlass.DataType.bf16: torch.bfloat16,
|
||||
cutlass.DataType.f32: torch.float,
|
||||
cutlass.DataType.f32: torch.float32,
|
||||
cutlass.DataType.f64: torch.double,
|
||||
cutlass.DataType.f64: torch.float64,
|
||||
cutlass.DataType.s8: torch.int8,
|
||||
cutlass.DataType.s32: torch.int32,
|
||||
cutlass.DataType.u8: torch.uint8,
|
||||
cutlass_cppgen.DataType.f16: torch.half,
|
||||
cutlass_cppgen.DataType.f16: torch.float16,
|
||||
cutlass_cppgen.DataType.bf16: torch.bfloat16,
|
||||
cutlass_cppgen.DataType.f32: torch.float,
|
||||
cutlass_cppgen.DataType.f32: torch.float32,
|
||||
cutlass_cppgen.DataType.f64: torch.double,
|
||||
cutlass_cppgen.DataType.f64: torch.float64,
|
||||
cutlass_cppgen.DataType.s8: torch.int8,
|
||||
cutlass_cppgen.DataType.s32: torch.int32,
|
||||
cutlass_cppgen.DataType.u8: torch.uint8,
|
||||
}
|
||||
|
||||
def possibly_add_type(torch_type_name, cutlass_type):
|
||||
@@ -184,8 +184,8 @@ def is_torch_available():
|
||||
_torch_to_library_dict[torch_type] = cutlass_type
|
||||
_library_to_torch_dict[cutlass_type] = torch_type
|
||||
|
||||
possibly_add_type("float8_e4m3fn", cutlass.DataType.e4m3)
|
||||
possibly_add_type("float8_e5m2", cutlass.DataType.e5m2)
|
||||
possibly_add_type("float8_e4m3fn", cutlass_cppgen.DataType.e4m3)
|
||||
possibly_add_type("float8_e5m2", cutlass_cppgen.DataType.e5m2)
|
||||
|
||||
except ImportError:
|
||||
torch_available = False
|
||||
@@ -201,7 +201,7 @@ def is_torch_tensor(inp) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def torch_library_type(inp) -> cutlass.DataType:
|
||||
def torch_library_type(inp) -> cutlass_cppgen.DataType:
|
||||
return _torch_to_library_dict.get(inp, None)
|
||||
|
||||
|
||||
@@ -222,17 +222,17 @@ def is_bfloat16_available():
|
||||
return bfloat16_available
|
||||
|
||||
|
||||
def bfloat16_library_type(inp) -> cutlass.DataType:
|
||||
def bfloat16_library_type(inp) -> cutlass_cppgen.DataType:
|
||||
if is_bfloat16_available():
|
||||
import bfloat16
|
||||
if inp == bfloat16.bfloat16:
|
||||
return cutlass.DataType.bf16
|
||||
return cutlass_cppgen.DataType.bf16
|
||||
|
||||
|
||||
def bfloat16_type(inp):
|
||||
if is_bfloat16_available():
|
||||
import bfloat16
|
||||
if inp == cutlass.DataType.bf16:
|
||||
if inp == cutlass_cppgen.DataType.bf16:
|
||||
return bfloat16.bfloat16
|
||||
|
||||
|
||||
@@ -256,15 +256,15 @@ def library_type(inp):
|
||||
def _tensor_from_numpy(np_tensor):
|
||||
dtype = library_type(np_tensor.dtype)
|
||||
if np_tensor.flags.c_contiguous:
|
||||
layout = cutlass.LayoutType.RowMajor
|
||||
layout = cutlass_cppgen.LayoutType.RowMajor
|
||||
elif np_tensor.flags.f_contiguous:
|
||||
layout = cutlass.LayoutType.ColumnMajor
|
||||
layout = cutlass_cppgen.LayoutType.ColumnMajor
|
||||
return (dtype, layout)
|
||||
|
||||
|
||||
def _tensor_from_torch(pt_tensor):
|
||||
dtype = library_type(pt_tensor.dtype)
|
||||
return (dtype, cutlass.LayoutType.RowMajor)
|
||||
return (dtype, cutlass_cppgen.LayoutType.RowMajor)
|
||||
|
||||
|
||||
def get_datatype_and_layout(tensor):
|
||||
@@ -273,7 +273,7 @@ def get_datatype_and_layout(tensor):
|
||||
elif is_torch_tensor(tensor):
|
||||
return _tensor_from_torch(tensor)
|
||||
elif isinstance(tensor, float) or isinstance(tensor, int):
|
||||
return (cutlass.DataType.f32, cutlass.LayoutType.RowMajor)
|
||||
return (cutlass_cppgen.DataType.f32, cutlass_cppgen.LayoutType.RowMajor)
|
||||
else:
|
||||
raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
|
||||
|
||||
@@ -303,10 +303,10 @@ def backend_math_operation(math_op: MathOperation):
|
||||
return _math_operation_value_map[math_op.value]
|
||||
|
||||
|
||||
def construct_backend_td(td: cutlass.TileDescription,
|
||||
kernel_schedule: cutlass.KernelScheduleType,
|
||||
epilogue_schedule: cutlass.EpilogueScheduleType,
|
||||
tile_scheduler: cutlass.TileSchedulerType) -> TileDescription:
|
||||
def construct_backend_td(td: cutlass_cppgen.TileDescription,
|
||||
kernel_schedule: cutlass_cppgen.KernelScheduleType,
|
||||
epilogue_schedule: cutlass_cppgen.EpilogueScheduleType,
|
||||
tile_scheduler: cutlass_cppgen.TileSchedulerType) -> TileDescription:
|
||||
mi = td.math_instruction
|
||||
backend_mi = MathInstruction(
|
||||
mi.instruction_shape,
|
||||
@@ -328,7 +328,7 @@ def td_from_profiler_op(op) -> TileDescription:
|
||||
:param op: profiler Operation
|
||||
|
||||
:returns: backend TileDescription
|
||||
:rtype: cutlass.backend.TileDescription
|
||||
:rtype: cutlass_cppgen.backend.TileDescription
|
||||
"""
|
||||
kschedule = op.kernel_schedule if hasattr(op, 'kernel_schedule') else None
|
||||
eschedule = op.epilogue_schedule if hasattr(op, 'epilogue_schedule') else None
|
||||
@@ -341,10 +341,10 @@ def td_from_profiler_td(td: TileDescription) -> TileDescription:
|
||||
Converts the profiler's TileDescription into the backend TileDescription
|
||||
|
||||
:param td: profiler TileDescription
|
||||
:type td: cutlass.TileDescription
|
||||
:type td: cutlass_cppgen.TileDescription
|
||||
|
||||
:returns: backend TileDescription
|
||||
:rtype: cutlass.backend.TileDescription
|
||||
:rtype: cutlass_cppgen.backend.TileDescription
|
||||
"""
|
||||
return construct_backend_td(td, kernel_schedule=None, epilogue_schedule=None, tile_scheduler=None)
|
||||
|
||||
|
||||
@@ -37,16 +37,16 @@ Profiler based on the cuda events
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
from cutlass.utils.lazy_import import lazy_import
|
||||
from cutlass_cppgen.utils.lazy_import import lazy_import
|
||||
cuda = lazy_import("cuda.cuda")
|
||||
cudart = lazy_import("cuda.cudart")
|
||||
import numpy as np
|
||||
|
||||
from cutlass import CUTLASS_PATH
|
||||
from cutlass.backend.library import DataTypeSize
|
||||
from cutlass.op.op import OperationBase
|
||||
from cutlass.shape import GemmCoord
|
||||
from cutlass.utils.datatypes import is_numpy_tensor
|
||||
from cutlass_cppgen import CUTLASS_PATH
|
||||
from cutlass_cppgen.backend.library import DataTypeSize
|
||||
from cutlass_cppgen.op.op import OperationBase
|
||||
from cutlass_cppgen.shape import GemmCoord
|
||||
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
|
||||
|
||||
|
||||
class GpuTimer:
|
||||
|
||||
@@ -49,7 +49,6 @@ from . import rank_2k_operation
|
||||
from . import rank_k_operation
|
||||
from . import symm_operation
|
||||
from . import trmm_operation
|
||||
|
||||
# Make enum types from library.py accessible via cutlass_library.*
|
||||
from .library import *
|
||||
|
||||
|
||||
@@ -279,7 +279,7 @@ def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups
|
||||
def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
):
|
||||
# For functional testing, we prefer to run reference computing on device if any
|
||||
reference_device_archs = ["100a"]
|
||||
reference_device_archs = ["100a", "103a"]
|
||||
run_reference_on_device = True if arch in reference_device_archs and mode in ["functional_L0", "functional_L1"] else False
|
||||
profiler_flags_for_verification = "device" if run_reference_on_device else "host"
|
||||
|
||||
@@ -287,7 +287,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
# TODO: randomize beta values for wider coverage
|
||||
beta_values = [0.5]
|
||||
|
||||
is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "120a", "120f"])
|
||||
is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "103a", "110a", "110f", "120a", "120f", "121a", "121f"])
|
||||
|
||||
is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch
|
||||
|
||||
@@ -306,6 +306,10 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
'bf16gemm_f32_f32_f32_f32_f32',
|
||||
]
|
||||
|
||||
exclude_archs = arch not in ("103a")
|
||||
if exclude_archs:
|
||||
sm100_mma_data_type_general.append('gemm_s8_s8_s32_s8_s8')
|
||||
|
||||
sm100_mma_data_type_runtime_dtype = [
|
||||
'gemm.*f4_f4_f32_f32_f32',
|
||||
'gemm.*f6_f6_f32_f32_f32',
|
||||
@@ -344,6 +348,11 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
|
||||
]
|
||||
|
||||
sm103_block_scaled_data_type = [
|
||||
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
|
||||
'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
|
||||
]
|
||||
|
||||
block_scaled_cluster_size = [
|
||||
'4x4x1', '2x1x1',
|
||||
'0x0x1' # dynamic cluster
|
||||
@@ -354,6 +363,9 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
if arch in ["100a", "100f"]:
|
||||
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
@@ -361,15 +373,23 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
elif arch in ["101a", "101f",
|
||||
]:
|
||||
elif arch in ["101a", "101f", "110a", "110f"]:
|
||||
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({sm100_mma_filter_regex_1sm_runtime})|" \
|
||||
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
elif arch in ["120a", "120f"]:
|
||||
elif arch in ["103a"]:
|
||||
kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({sm100_mma_filter_regex_1sm_runtime})|" \
|
||||
f"({sm100_mma_filter_regex_2sm_runtime})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})|" \
|
||||
f"({sm103_block_scaled_filter_regex_1sm})|" \
|
||||
f"({sm103_block_scaled_filter_regex_2sm})"
|
||||
elif arch in ["120a", "120f", "121a", "121f"]:
|
||||
|
||||
# blockscaled sm120_mma kernels
|
||||
blockscaled_sm120_mma_kernel_cta_tiles = [
|
||||
@@ -384,7 +404,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
|
||||
kernel_filter = f"({filter_regex_blockscaled_sm120_mma})"
|
||||
else:
|
||||
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm120a, sm120f"
|
||||
error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm110a, sm110f, sm103a, sm120a, sm120f, sm121a, sm121f"
|
||||
raise Exception(error_message)
|
||||
|
||||
elif mode == "functional_L1":
|
||||
@@ -403,16 +423,27 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
'ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
|
||||
]
|
||||
|
||||
block_scaled_cluster_size = ['4x4x1', '2x1x1', '0x0x1']
|
||||
sm103_block_scaled_data_type = [
|
||||
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
|
||||
'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
|
||||
]
|
||||
|
||||
block_scaled_cluster_size = ['0x0x1']
|
||||
block_scaled_layouts = ['tnt']
|
||||
|
||||
# regex list must be in kernel procedural name order
|
||||
block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
|
||||
sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
|
||||
|
||||
filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \
|
||||
f"({sm100_mma_filter_regex_2sm})|" \
|
||||
f"({block_scaled_filter_regex_1sm})|" \
|
||||
f"({block_scaled_filter_regex_2sm})"
|
||||
f"({block_scaled_filter_regex_2sm})" \
|
||||
f"({sm103_block_scaled_filter_regex_1sm})|" \
|
||||
f"({sm103_block_scaled_filter_regex_2sm})"
|
||||
# CTA tiles for sm120 MMA - only run one tile size to reduce build/test times
|
||||
sm120_mma_kernel_cta_tiles = [
|
||||
# h1688, s1688, i16832, i8816
|
||||
@@ -449,7 +480,10 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
|
||||
|
||||
problem_waves = [0.5, 1.25, 2.5]
|
||||
|
||||
kernel_filter = f"({filter_regex_sm100_mma})|({filter_regex_sm120_mma})"
|
||||
if arch in ["120a", "120f", "121a", "121f"]:
|
||||
kernel_filter = f"({filter_regex_sm120_mma})"
|
||||
else:
|
||||
kernel_filter = f"({filter_regex_sm100_mma})"
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
|
||||
@@ -341,7 +341,7 @@ class GemmOperation:
|
||||
Get the tile shape passed to the collective builder.
|
||||
On Blackwell, this is different than the operation.tile_description.tile_shape.
|
||||
"""
|
||||
is_sm100_kernel = (self.arch == 100)
|
||||
is_sm100_kernel = (self.arch == 100 or self.arch == 103)
|
||||
if not is_sm100_kernel:
|
||||
return self.tile_description.tile_shape
|
||||
|
||||
@@ -995,6 +995,24 @@ ${compile_guard_end}
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]
|
||||
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103:
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm]
|
||||
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103:
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm]
|
||||
|
||||
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103:
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm]
|
||||
if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103:
|
||||
epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
|
||||
if not is_no_smem_epilogue:
|
||||
epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm]
|
||||
|
||||
element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>'
|
||||
element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>'
|
||||
|
||||
|
||||
@@ -90,10 +90,12 @@ try:
|
||||
raise ImportError("Disabling attempt to import cutlass_library")
|
||||
from cutlass_library.library import *
|
||||
from cutlass_library.manifest import *
|
||||
from cutlass_library.heuristics import *
|
||||
from cutlass_library.emit_kernel_listing import emit_gemm_kernel_testlist
|
||||
except ImportError:
|
||||
from library import *
|
||||
from manifest import *
|
||||
from heuristics import *
|
||||
from emit_kernel_listing import emit_gemm_kernel_testlist
|
||||
###################################################################################################
|
||||
|
||||
@@ -112,6 +114,10 @@ def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0):
|
||||
cuda_version.append(x)
|
||||
return cuda_version >= [major, minor, patch]
|
||||
|
||||
# From cuda 13.0, Thor SM is renumbered from 101 to 110
|
||||
def ThorSMRenumbering(cuda_version):
|
||||
return 110 if CudaToolkitVersionSatisfies(cuda_version, 13, 0) else 101
|
||||
|
||||
###################################################################################################
|
||||
###################################################################################################
|
||||
|
||||
@@ -6768,9 +6774,11 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
},
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
math_instructions_1sm = [
|
||||
# tf32 -> f32
|
||||
MathInstruction(
|
||||
@@ -6887,7 +6895,8 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
[[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 0]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
grouped = is_grouped(gemm_kind)
|
||||
@@ -7202,9 +7211,11 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK
|
||||
[[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
epi_type = DataType.f32
|
||||
grouped = is_grouped(gemm_kind)
|
||||
|
||||
@@ -7889,9 +7900,11 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
|
||||
TileSchedulerType.Default, TileSchedulerType.StreamK
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
epi_type = DataType.f32
|
||||
|
||||
math_instructions_1sm = []
|
||||
@@ -8092,6 +8105,8 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
|
||||
return
|
||||
|
||||
grouped = is_grouped(gemm_kind)
|
||||
|
||||
layouts = [
|
||||
[[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]],
|
||||
[[LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 128], [LayoutType.RowMajor, 0]],
|
||||
@@ -8120,14 +8135,16 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
def tile_schedulers(sfdtype):
|
||||
# Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void,
|
||||
# the epilogue is the traditional linear combination, for which we already have tests with stream-K.
|
||||
if sfdtype["type"] == DataType.void:
|
||||
if sfdtype["type"] == DataType.void or grouped:
|
||||
return [TileSchedulerType.Default]
|
||||
else:
|
||||
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
epi_type = DataType.f32
|
||||
|
||||
math_instructions_1sm = []
|
||||
@@ -8209,6 +8226,16 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : DataType.bf16,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
@@ -8246,7 +8273,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
|
||||
for data_type in data_types:
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]]
|
||||
[[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]]
|
||||
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
|
||||
|
||||
cluster_shapes_2sm = [
|
||||
@@ -8288,6 +8315,16 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : DataType.bf16,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
@@ -8346,7 +8383,11 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
if tile.threadblock_shape[2] // tile.cluster_shape[2] % layout[1][1] != 0:
|
||||
continue
|
||||
|
||||
if math_inst.instruction_shape[0] == 128:
|
||||
if grouped:
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type],
|
||||
[[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]]
|
||||
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
|
||||
elif math_inst.instruction_shape[0] == 128:
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type],
|
||||
[[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]]
|
||||
, tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind)
|
||||
@@ -8396,9 +8437,11 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
else:
|
||||
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
epi_type = DataType.f32
|
||||
|
||||
math_instructions_1sm = []
|
||||
@@ -8496,6 +8539,16 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.bf16,
|
||||
"d_type" : DataType.bf16,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
@@ -8625,6 +8678,16 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.bf16,
|
||||
"d_type" : DataType.bf16,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
@@ -8715,6 +8778,230 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
|
||||
|
||||
|
||||
def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x):
|
||||
# SM100 MMA with F4 + block scale
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
|
||||
return
|
||||
|
||||
# layouts for ABC and their alignments.
|
||||
layouts = [
|
||||
[[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]],
|
||||
]
|
||||
|
||||
instruction_sizes_1sm = [
|
||||
[128, 128, 96],
|
||||
]
|
||||
|
||||
instruction_sizes_2sm = [
|
||||
[256, 128, 96],
|
||||
]
|
||||
|
||||
ab_types = [
|
||||
DataType.e2m1,
|
||||
]
|
||||
|
||||
acc_types = [ DataType.f32 ] # Accumulator is always 32 bits for block scaled MMA instructions
|
||||
|
||||
min_cc = 103
|
||||
max_cc = 103
|
||||
epi_type = DataType.f32
|
||||
|
||||
math_instructions_1sm = []
|
||||
|
||||
is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8)
|
||||
|
||||
for instr_size, a_type, b_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, acc_types):
|
||||
is_runtime_datatype_a = is_runtime_datatype(a_type)
|
||||
is_runtime_datatype_b = is_runtime_datatype(b_type)
|
||||
|
||||
# A/B datatypes should be both static or dynamic
|
||||
if (is_runtime_datatype_a != is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
math_instructions_1sm.append(
|
||||
MathInstruction(
|
||||
instr_size,
|
||||
a_type, b_type, acc_type,
|
||||
OpcodeClass.BlockScaledTensorOp,
|
||||
MathOperation.multiply_add,
|
||||
DataType.ue8m0) # UE8M0 scale factor
|
||||
)
|
||||
|
||||
math_instructions_2sm = []
|
||||
|
||||
for instr_size, a_type, b_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, acc_types):
|
||||
is_runtime_datatype_a = is_runtime_datatype(a_type)
|
||||
is_runtime_datatype_b = is_runtime_datatype(b_type)
|
||||
|
||||
# A/B datatypes should be both static or dynamic
|
||||
if (is_runtime_datatype_a != is_runtime_datatype_b):
|
||||
continue
|
||||
|
||||
math_instructions_2sm.append(
|
||||
MathInstruction(
|
||||
instr_size,
|
||||
a_type, b_type, acc_type,
|
||||
OpcodeClass.BlockScaledTensorOp,
|
||||
MathOperation.multiply_add,
|
||||
DataType.ue8m0) # UE8M0 scale factor
|
||||
)
|
||||
|
||||
cluster_shapes_1sm = [
|
||||
[1,1,1],
|
||||
# [1,2,1],
|
||||
[2,1,1],
|
||||
# [1,4,1],
|
||||
[4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
# 1xSM MMA kernels
|
||||
for math_inst in math_instructions_1sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_1sm:
|
||||
multiplier_1sm = cluster_shape
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
math_inst.instruction_shape[0] * multiplier_1sm[0],
|
||||
math_inst.instruction_shape[1] * multiplier_1sm[1],
|
||||
768],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
|
||||
|
||||
data_types = [
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : DataType.f32,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.bf16,
|
||||
"d_type" : DataType.bf16,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : DataType.e2m1,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor}
|
||||
},
|
||||
]
|
||||
|
||||
for layout in layouts:
|
||||
for data_type in data_types:
|
||||
# Set alignment d based on Destination format.
|
||||
if DataTypeSize[data_type["c_type"]] == 0 :
|
||||
layout[2][1] = 256 // DataTypeSize[data_type["d_type"]]
|
||||
else:
|
||||
layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]])
|
||||
|
||||
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
|
||||
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
|
||||
# E2M1 x E2M1, vector size 32, E8
|
||||
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
|
||||
|
||||
fp4_schedule = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, EpilogueScheduleType.NoSmemWarpSpecialized1Sm]
|
||||
fp4_schedule_disable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, EpilogueScheduleType.NoSmemWarpSpecialized1Sm]
|
||||
fp4_schedule_enable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, EpilogueScheduleType.NoSmemWarpSpecialized1Sm]
|
||||
# For FP4 inputs
|
||||
if isFp4:
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule, fp4_schedule_disable_prefetch
|
||||
,fp4_schedule_enable_prefetch
|
||||
]
|
||||
, gemm_kind=gemm_kind
|
||||
)
|
||||
|
||||
cluster_shapes_2sm = [
|
||||
[2,1,1],
|
||||
# [2,2,1],
|
||||
# [2,4,1],
|
||||
[4,1,1],
|
||||
# [4,2,1],
|
||||
[4,4,1]
|
||||
, DynamicClusterShape
|
||||
]
|
||||
|
||||
for math_inst in math_instructions_2sm:
|
||||
tile_descriptions = []
|
||||
for cluster_shape in cluster_shapes_2sm:
|
||||
multiplier_2sm = (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2])
|
||||
tile_descriptions.append(
|
||||
TileDescription([
|
||||
math_inst.instruction_shape[0] * multiplier_2sm[0],
|
||||
math_inst.instruction_shape[1] * multiplier_2sm[1],
|
||||
math_inst.instruction_shape[2] * 8 * multiplier_2sm[2]],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))
|
||||
|
||||
data_types = [
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : DataType.f32,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.bf16,
|
||||
"d_type" : DataType.bf16,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None}
|
||||
},
|
||||
|
||||
{
|
||||
"a_type" : math_inst.element_a,
|
||||
"b_type" : math_inst.element_b,
|
||||
"c_type" : DataType.void,
|
||||
"d_type" : DataType.e2m1,
|
||||
"acc_type" : math_inst.element_accumulator,
|
||||
"epi_type" : epi_type,
|
||||
"sf_type" : math_inst.element_scale_factor,
|
||||
"sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor}
|
||||
},
|
||||
]
|
||||
|
||||
for layout in layouts:
|
||||
for data_type in data_types:
|
||||
# Set alignment d based on Destination format.
|
||||
if DataTypeSize[data_type["c_type"]] == 0 :
|
||||
layout[2][1] = 256 // DataTypeSize[data_type["d_type"]]
|
||||
else:
|
||||
layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]])
|
||||
|
||||
if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1):
|
||||
data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout.
|
||||
# E2M1 x E2M1, vector size 32, E8
|
||||
isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1
|
||||
|
||||
fp4_schedule = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, EpilogueScheduleType.NoSmemWarpSpecialized2Sm]
|
||||
fp4_schedule_disable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, EpilogueScheduleType.NoSmemWarpSpecialized2Sm]
|
||||
fp4_schedule_enable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, EpilogueScheduleType.NoSmemWarpSpecialized2Sm]
|
||||
# For FP4 inputs
|
||||
if isFp4:
|
||||
CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule, fp4_schedule_disable_prefetch
|
||||
,fp4_schedule_enable_prefetch
|
||||
]
|
||||
, gemm_kind=gemm_kind
|
||||
)
|
||||
|
||||
def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
|
||||
@@ -8732,7 +9019,8 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
[[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
@@ -8948,9 +9236,11 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version):
|
||||
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default,
|
||||
]
|
||||
@@ -9074,9 +9364,11 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version):
|
||||
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default,
|
||||
]
|
||||
@@ -9200,7 +9492,8 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version):
|
||||
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
@@ -9326,9 +9619,11 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version):
|
||||
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default,
|
||||
]
|
||||
@@ -9465,9 +9760,11 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version):
|
||||
[[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
tile_schedulers = [
|
||||
TileSchedulerType.Default,
|
||||
]
|
||||
@@ -9678,9 +9975,11 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
}
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
math_instructions_1sm = [
|
||||
MathInstruction(
|
||||
[128, 256, 8],
|
||||
@@ -9772,9 +10071,11 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
[[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
math_instructions_1sm = [
|
||||
MathInstruction(
|
||||
[128, 256, 16],
|
||||
@@ -9934,9 +10235,11 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version):
|
||||
[[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]],
|
||||
]
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
min_cc = 100
|
||||
max_cc = thor_sm
|
||||
|
||||
epi_type = DataType.f32
|
||||
|
||||
math_instructions_1sm = [
|
||||
@@ -10084,7 +10387,8 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version,
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
|
||||
return
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
minimum_compute_capability = 100
|
||||
maximum_compute_capability = thor_sm
|
||||
|
||||
@@ -10238,7 +10542,8 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version,
|
||||
if not CudaToolkitVersionSatisfies(cuda_version, 12, 0):
|
||||
return
|
||||
|
||||
thor_sm = 101
|
||||
thor_sm = ThorSMRenumbering(cuda_version)
|
||||
|
||||
minimum_compute_capability = 100
|
||||
maximum_compute_capability = thor_sm
|
||||
|
||||
@@ -10422,7 +10727,7 @@ def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud
|
||||
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
|
||||
|
||||
min_cc = 120
|
||||
max_cc = 120
|
||||
max_cc = 121
|
||||
|
||||
epi_type = DataType.f32
|
||||
|
||||
@@ -10567,7 +10872,7 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio
|
||||
return [TileSchedulerType.Default, TileSchedulerType.StreamK]
|
||||
|
||||
min_cc = 120
|
||||
max_cc = 120
|
||||
max_cc = 121
|
||||
|
||||
epi_type = DataType.f32
|
||||
|
||||
@@ -10720,7 +11025,7 @@ def GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version):
|
||||
return [TileSchedulerType.Default]
|
||||
|
||||
min_cc = 120
|
||||
max_cc = 120
|
||||
max_cc = 121
|
||||
|
||||
kernel_schedules = [
|
||||
KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120,
|
||||
@@ -10840,7 +11145,7 @@ def GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version,
|
||||
return [TileSchedulerType.Default]
|
||||
|
||||
min_cc = 120
|
||||
max_cc = 120
|
||||
max_cc = 121
|
||||
|
||||
kernel_schedulers = [
|
||||
KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120,
|
||||
@@ -10924,7 +11229,11 @@ def GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version,
|
||||
gemm_kind = gemm_kind)
|
||||
|
||||
def GenerateSM100(manifest, cuda_version):
|
||||
arch_family_cc = ['100f', '101f']
|
||||
arch_family_cc = ['100f', '101f', '103a']
|
||||
if CudaToolkitVersionSatisfies(cuda_version, 13, 0):
|
||||
for old_cc, new_cc in [('101f', '110f')]:
|
||||
arch_family_cc = [cc.replace(old_cc, new_cc) for cc in arch_family_cc]
|
||||
|
||||
#
|
||||
# Dense Gemm
|
||||
#
|
||||
@@ -10966,8 +11275,11 @@ def GenerateSM100(manifest, cuda_version):
|
||||
# Block Scaled Gemm
|
||||
#
|
||||
GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version)
|
||||
GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
|
||||
GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version)
|
||||
GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x)
|
||||
|
||||
GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version)
|
||||
#
|
||||
# Conv
|
||||
#
|
||||
@@ -11413,7 +11725,6 @@ def numeric_log_level(log_level: str) -> int:
|
||||
raise ValueError(f'Invalid log level: {log_level}')
|
||||
return numeric_level
|
||||
|
||||
|
||||
# This function for defining the ArgumentParser is used to make it easy for the CUTLASS Python interface
|
||||
# to leverage the functionality in this file without running this script via a shell prompt.
|
||||
def define_parser():
|
||||
@@ -11438,6 +11749,11 @@ def define_parser():
|
||||
parser.add_argument("--filter-by-cc", default='True', type=str, help='If enabled, kernels whose compute capability range is not satisfied by the build target are excluded.')
|
||||
parser.add_argument("--cuda-version", default="11.0.0", help="Semantic version string of CUDA Toolkit")
|
||||
parser.add_argument('--kernel-filter-file', type=str, default=None, required=False, help='Full path of filter file')
|
||||
parser.add_argument('--heuristics-problems-file', type=str, default=None, required=False, help='Full path of heuristics problem size description file, as a json list')
|
||||
parser.add_argument('--heuristics-testlist-file', type=str, default=None, required=False, help='Full path of heuristics testlist CSV file, to be passed to cutlass_profiler')
|
||||
parser.add_argument('--heuristics-gpu', type=str, default=None, required=False, help='GPU to use for evaluating heuristics offline. None or `auto` to autodetect using cuda', choices=['', 'auto', 'H100_SXM', 'H100_PCIE', 'H100_NVL', 'H200_SXM', 'H20_SXM', 'B200', 'GB200_NVL', 'RTX_5080', 'RTX_5090', 'RTX_PRO_6000'])
|
||||
parser.add_argument('--heuristics-configs-per-problem', type=int, default=10, required=False, help='Number of kernel configs to generate for each problem in the problem list')
|
||||
parser.add_argument('--heuristics-restrict-kernels', action='store_true', help='Restrict heuristics mode to use only the default set of kernels emitted by generator.py')
|
||||
parser.add_argument('--selected-kernel-list', type=str, default=None, required=False,
|
||||
help='Specify the output log file containing all enabled kernels in this build')
|
||||
parser.add_argument("--interface-dir", default=None, required=False, help="Interface header to kernels")
|
||||
@@ -11460,6 +11776,9 @@ if __name__ == "__main__":
|
||||
|
||||
archs = args.architectures.split(';')
|
||||
|
||||
if args.heuristics_problems_file:
|
||||
filter_manifest_and_write_heuristics_file(manifest, args)
|
||||
|
||||
GenerateSM50(manifest, args.cuda_version)
|
||||
GenerateSM60(manifest, args.cuda_version)
|
||||
GenerateSM61(manifest, args.cuda_version)
|
||||
@@ -11468,17 +11787,20 @@ if __name__ == "__main__":
|
||||
GenerateSM80(manifest, args.cuda_version)
|
||||
GenerateSM89(manifest, args.cuda_version)
|
||||
GenerateSM90(manifest, args.cuda_version)
|
||||
|
||||
|
||||
blackwell_arch_list = [
|
||||
"100a", "100f",
|
||||
"101a", "101f",
|
||||
"120a", "120f"
|
||||
"103a", "103f",
|
||||
"110a", "110f",
|
||||
"120a", "120f",
|
||||
"121a", "121f",
|
||||
]
|
||||
blackwell_enabled_arch = any(arch in blackwell_arch_list for arch in archs)
|
||||
if blackwell_enabled_arch:
|
||||
GenerateSM100(manifest, args.cuda_version)
|
||||
GenerateSM120(manifest, args.cuda_version)
|
||||
|
||||
|
||||
if 'library' in args.generator_target.split(','):
|
||||
manifest.emit(GeneratorTarget.Library)
|
||||
|
||||
|
||||
414
python/cutlass_library/heuristics.py
Normal file
414
python/cutlass_library/heuristics.py
Normal file
@@ -0,0 +1,414 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Utilities for selecting CUTLASS library kernels based on problem description
|
||||
"""
|
||||
import json
|
||||
import csv
|
||||
|
||||
try:
|
||||
if CUTLASS_IGNORE_PACKAGE:
|
||||
raise ImportError("Disabling attempt to import cutlass_library")
|
||||
from cutlass_library.library import *
|
||||
from cutlass_library.generator import *
|
||||
from cutlass_library.heuristics_provider import *
|
||||
except ImportError:
|
||||
from library import *
|
||||
from generator import *
|
||||
from heuristics_provider import *
|
||||
|
||||
try:
|
||||
from .sm90_utils import (
|
||||
get_valid_schedules,
|
||||
generate_data_types_from_math_instruction,
|
||||
fix_alignments,
|
||||
)
|
||||
except ImportError:
|
||||
from sm90_utils import (
|
||||
get_valid_schedules,
|
||||
generate_data_types_from_math_instruction,
|
||||
fix_alignments,
|
||||
)
|
||||
|
||||
# Module-level logger for heuristics-selection diagnostics.
_LOGGER = logging.getLogger(__name__)

# Reverse mapping from short dtype name (e.g. 'f16') back to the DataType enum,
# built from the forward DataTypeNames table.
dtype_map = {v: k for k, v in DataTypeNames.items()}
|
||||
|
||||
def serialize_heuristics_results_to_json(problems_with_configs, outfile_path):
    """
    Utility function to write heuristics results to a json file for debug.

    DataType and LayoutType enum members are replaced with their short string
    names so that the structure is JSON-serializable.

    args:
        problems_with_configs: List of problems provided to the heuristic, with a list of operations added to each problem dict
        outfile_path: Outfile path

    returns:
        None
    """
    import copy

    # Deep-copy the input: a shallow list.copy() would share the problem/config
    # dicts, and replacing enum values with strings below would mutate the
    # caller's data as a side effect.
    pc_copy = copy.deepcopy(problems_with_configs)
    for p in pc_copy:
        for k, v in p.items():
            if isinstance(v, DataType):
                p[k] = DataTypeNames[v]
            elif isinstance(v, LayoutType):
                p[k] = ShortLayoutTypeNames[v]
        for c in p['configs']:
            for k, v in c.items():
                if isinstance(v, DataType):
                    c[k] = DataTypeNames[v]
                elif isinstance(v, LayoutType):
                    c[k] = ShortLayoutTypeNames[v]
    with open(outfile_path, 'w') as f:
        json.dump(pc_copy, f, indent=2)
|
||||
|
||||
def get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, voidC=False, use_fast_acc=True, count=1, provider=None):
    """
    Get heuristic-suggested GEMM kernel configurations for a single GEMM problem.

    args:
        m, n, k: GEMM dimensions
        batch_count: batch count
        layouts: tuple of layouts of type LayoutType
        dtypes: tuple of DataType values (a, b, accumulator, c, d)
        alignment_a, alignment_b: input tensor alignments, in count of elements
        voidC: whether the C operand is unused (e.g. beta == 0)
        use_fast_acc: Use fast accumulation for FP8. Ignored for other precisions
        count: Number of configs to return
        provider: Heuristics provider to use; a default MatmulHeuristics is constructed if None

    returns:
        A list of dictionaries containing the suggested kernel configurations and additional info from the input required to define a Cutlass GemmOperation, with the following keys:
        - 'cta_tile_m', 'cta_tile_n', 'cta_tile_k': CTA tile size
        - 'instr_tile_m', 'instr_tile_n', 'instr_tile_k': Instruction tile size
        - 'stages': kernel pipeline stage count
        - 'cluster_m', 'cluster_n', 'cluster_k': cluster size
        - 'layout_a', 'layout_b': input tensor layouts of type LayoutType
        - 'alignment_a', 'alignment_b': input tensor alignments, in count of elements
        - 'dtype_a', 'dtype_b', 'dtype_acc': dtypes of a, b, and accumulator, of type DataType
        - 'swizzle_size' : suggested threadblock swizzle
        - 'split_k_slices': number of partitions of the k dimension for splitK
        - 'raster_order': raster order for CTAs over output tiles ('along_m' or 'along_n')
    """
    if provider is None:
        provider = MatmulHeuristics()
    return provider.get_configs(m, n, k, batch_count, dtypes, layouts, alignment_a, alignment_b, voidC=voidC, use_fast_acc=use_fast_acc, count=count)
|
||||
|
||||
def get_gemm_configs(problems, provider=None, count=1):
    """
    Get heuristic-suggested GEMM kernel configurations for a set of GEMM problems.

    args:
        problems: List of dictionaries describing GEMM problems with the following keys:
            - 'm', 'n', 'k': Matrix dimensions (required)
            - 'dtype_a': Data type of matrix A (required)
            - 'dtype_b': Data type of matrix B (required)
            - 'dtype_c': Data type of matrix C (default: None, in which case dtype_d is used)
            - 'dtype_d': Data type of matrix D (required)
            - 'dtype_acc': Compute data type (default: 'f32')
            - 'layout': Operation layout (e.g. 'tnt')
            - 'alignment_a': Memory access granularity of A, in units of elements (default: 16 bytes equivalent elements)
            - 'alignment_b': Memory access granularity of B, in units of elements (default: 16 bytes equivalent elements)
            - 'alpha': Scalar multiplier for A*B (default: 1.0; accepted but not used for config selection)
            - 'beta': Scalar multiplier for C (default: 0.0; beta == 0 requests a void-C epilogue)
            - 'batch_count': Number of GEMM operations in batch (default: 1)
            - 'use_fast_acc': Enable fast accumulation for FP8 on Hopper (default: True)
        provider: Heuristics provider to use
        count: Number of configurations to return per problem (default: 1)

    returns:
        A list of copies of the input problem dicts, each with key `configs` added containing the selected gemm configs

    raises:
        KeyError: if a required problem key is missing or a dtype name is unsupported
        ValueError: if the operation is not 'gemm' or the layout string is invalid
    """
    ret = []

    for problem in problems:
        # Copy so the 'configs' key is added to our copy, not the caller's dict.
        problem = problem.copy()

        try:
            m = problem['m']
            n = problem['n']
            k = problem['k']
            dtype_a = problem['dtype_a']
            dtype_b = problem['dtype_b']
            dtype_d = problem['dtype_d']
            layout = problem['layout']
        except KeyError as e:
            _LOGGER.error(f"Missing required parameter {e} for problem {problem}")
            raise

        operation = problem.get('operation', 'gemm')
        batch_count = problem.get('batch_count', 1)
        dtype_acc = problem.get('dtype_acc', 'f32')
        dtype_c = problem.get('dtype_c', None)
        beta = problem.get('beta', 0.0)
        use_fast_acc = problem.get('use_fast_acc', True)

        if operation != OperationKindNames[OperationKind.Gemm]:
            raise ValueError(f"Unsupported operation {operation}")
        if not (len(layout) == 3 and all(c in "nt" for c in layout)):
            raise ValueError(f"layout must be a 3-character string containing only 'n' or 't', got {layout}")
        # 't' (transposed) maps to row-major, 'n' to column-major.
        layouts = tuple(LayoutType.RowMajor if l == 't' else LayoutType.ColumnMajor for l in layout)

        try:
            # Order is (a, b, acc, c, d); a missing dtype_c falls back to dtype_d.
            dtype_list = [dtype_a.lower(), dtype_b.lower(), dtype_acc.lower(), dtype_c.lower() if dtype_c is not None else dtype_d.lower(), dtype_d.lower()]
            dtypes = tuple(dtype_map[dt] for dt in dtype_list)
        except KeyError as dt:
            _LOGGER.error(f"Unsupported data type: {dt}")
            raise

        # Default alignments correspond to a 128-bit (16-byte) access for each dtype.
        alignment_a = problem.get('alignment_a', 128 // DataTypeSize[dtypes[0]])
        alignment_b = problem.get('alignment_b', 128 // DataTypeSize[dtypes[1]])

        # beta == 0 means C is never read, so a void-C kernel may be selected.
        configs = get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, beta == 0.0, use_fast_acc, count, provider)
        problem['configs'] = configs

        ret.append(problem)

    return ret
|
||||
|
||||
|
||||
def generate_sm100_from_heuristics_configs(manifest, cuda_version, kernel_configs):
    """
    Generate CUTLASS operations based on the list of configs provided by the heuristic provider.

    args:
        manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest)
        cuda_version: Cuda compiler version for generating cutlass operations
        kernel_configs: list of configs generated by the heuristic

    returns:
        (configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations
    """
    # Target compute capability range for this generation path (SM100 - SM101).
    min_cc = 100
    max_cc = 101
    if manifest is None:
        # Use a dummy manifest so we can use existing CreateGemmOperator functions
        manifest = Manifest()

    configs = []
    operations = []
    for config in kernel_configs:
        # Each layout entry pairs a LayoutType with its alignment in elements;
        # the D alignment is derived from a 128-bit access on dtype_d.
        layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [config['layout_d'], 128 // DataTypeSize[config['dtype_d']]])
        # NOTE(review): element_c is unpacked for completeness but unused here;
        # c_type below is derived from voidC / the accumulator type instead.
        element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d']

        # nvMMH assumes 2sm instruction for !(cluster_m % 2)
        is_2sm = config['cluster_m'] % 2 == 0
        # The instruction M dimension doubles for 2SM instructions; the
        # instruction K is a quarter of the CTA tile K (undone by * 4 below).
        instruction_shape = [(2 * config['cta_tile_m']) if is_2sm else config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k'] // 4]
        math_instruction = MathInstruction(
            instruction_shape,
            element_a, element_b, element_accumulator,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add
        )

        data_types = [
            {
                "a_type" : math_instruction.element_a,
                "b_type" : math_instruction.element_b,
                "c_type" : DataType.void if config['voidC'] else math_instruction.element_accumulator,
                "d_type" : element_d,
                "acc_type" : math_instruction.element_accumulator,
                "epi_type" : math_instruction.element_accumulator,
            }
        ]

        # Scale the instruction shape back up to the full tile across the
        # cluster; M is divided by 2 for 2SM since instruction M was doubled.
        tile_multiplier = (config['cluster_m'] // (2 if is_2sm else 1), config['cluster_n'], config['cluster_k'])
        tile_description = TileDescription(
            [instruction_shape[0] * tile_multiplier[0],
             instruction_shape[1] * tile_multiplier[1],
             instruction_shape[2] * 4 * tile_multiplier[2]],
            0,
            [4,1,1],
            math_instruction,
            min_cc,
            max_cc,
            cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k'])
        )

        # Select the matching 1SM / 2SM kernel+epilogue schedule pair.
        schedules = []
        if is_2sm:
            schedules.append([KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm])
        else:
            schedules.append([KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm])

        # Pair every emitted operation with the config that produced it so the
        # returned lists stay one-to-one.
        for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, tile_schedulers=[TileSchedulerType.Default, TileSchedulerType.StreamK], gemm_kind=GemmKind.Universal3x):
            configs.append(config)
            operations.append(o)


    return configs, operations
|
||||
|
||||
|
||||
def generate_sm90_from_heuristics_configs(manifest, cuda_version, kernel_configs):
    """
    Generate CUTLASS operations based on the list of configs provided by the heuristic provider.

    args:
        manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest)
        cuda_version: Cuda compiler version for generating cutlass operations
        kernel_configs: list of configs generated by the heuristic

    returns:
        (configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations
    """
    # Target compute capability for this generation path (Hopper / SM90 only).
    min_cc, max_cc = 90, 90

    if manifest is None:
        # Use a dummy manifest so we can use existing CreateGemmOperator functions
        manifest = Manifest()

    configs = []
    operations = []
    for config in kernel_configs:

        # A and B count as aligned when each alignment supports a 128-bit access.
        is_aligned = (config['alignment_a'] * DataTypeSize[config['dtype_a']] >= 128) and (config['alignment_b'] * DataTypeSize[config['dtype_b']] >= 128)
        # Output layout is fixed to column-major with alignment 1 here;
        # fix_alignments below may widen it for the aligned case.
        layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [LayoutType.ColumnMajor, 1])
        element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d']

        # instr shape and warp config are unused for emitting 3x collective builder code
        dummy_instr_shape = [0, 0, 0]
        math_instruction = MathInstruction(
            dummy_instr_shape,
            element_a, element_b, element_accumulator,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add
        )

        data_types = generate_data_types_from_math_instruction(math_instruction, element_source=element_c, element_dest=element_d)
        if is_aligned:
            layout = fix_alignments(data_types, layout, alignment_bits=128)

        # instr shape and warp config are unused for emitting 3x collective builder code
        dummy_warp_count = [0, 0, 0]
        tile_description = TileDescription(
            [config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k']],
            0,
            dummy_warp_count,
            math_instruction,
            min_cc,
            max_cc,
            cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k'])
        )

        schedules, stream_k_schedules = get_valid_schedules(
            tile_description=tile_description,
            cuda_version=cuda_version,
            is_aligned=is_aligned,
            data_types=data_types,
            instantiation_level=9000, # don't prune schedules: we didn't get any schedule suggestion from the heuristic
            layout=layout,
            gemm_kind=GemmKind.Universal3x,
            enable_fp8_fast_acc=config['use_fast_acc']
        )

        # Pair every emitted operation with the config that produced it so the
        # returned lists stay one-to-one.
        if len(schedules):
            for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, gemm_kind=GemmKind.Universal3x):
                configs.append(config)
                operations.append(o)

        # Stream-K schedule variants are emitted separately with the stream-K
        # tile scheduler.
        if len(stream_k_schedules):
            for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types,
                                                   stream_k_schedules,
                                                   tile_schedulers=[TileSchedulerType.StreamK]):
                configs.append(config)
                operations.append(o)


    return configs, operations
|
||||
|
||||
def filter_manifest_and_write_heuristics_file(manifest, args):
    """
    Prune a manifest according to heuristics suggestions from the problems file.

    args:
        manifest: Cutlass manifest to prune
        args: generator.py args, requires:
            - args.heuristics_problems_file
            - args.heuristics_gpu
            - args.heuristics_testlist_file
            - args.heuristics_configs_per_problem
            - args.heuristics_restrict_kernels
            - args.architectures
            - args.cuda_version

    returns:
        A list of dictionaries, each of which has information about an operation and a problem from the input problems

    raises:
        Exception: if no valid configurations were generated for any problem
    """
    with open(args.heuristics_problems_file, 'r') as f:
        heuristics_problems = json.load(f)
    # "auto" or empty means let the provider auto-detect the GPU.
    gpu = None if args.heuristics_gpu in ("auto", "") else args.heuristics_gpu
    mmh = MatmulHeuristics(gpu=gpu)
    archs = args.architectures.split(';')
    if any('100' in arch for arch in archs):
        # Constrain heuristic CTA tile N to multiples of 64 for SM100 targets.
        mmh.set_cta_div_n(64)
    problems_with_configs = get_gemm_configs(heuristics_problems, provider=mmh, count=args.heuristics_configs_per_problem)

    # When kernels are restricted, pass None so generation uses a throwaway
    # manifest instead of adding operations to the real one.
    target_manifest = None if args.heuristics_restrict_kernels else manifest

    all_configs_and_operations = []
    operations = []
    for problem in problems_with_configs:
        # Accumulate per-architecture results. Previously the SM100 pass
        # overwrote the SM90 results when both architectures were requested
        # (dropping SM90 operations from the kernel filter and testlist), and
        # no matching architecture left the variables unbound.
        problem_configs = []
        problem_operations = []
        if any('90' in arch for arch in archs):
            cfgs, ops = generate_sm90_from_heuristics_configs(target_manifest, args.cuda_version, problem['configs'])
            problem_configs += cfgs
            problem_operations += ops
        if any(('100' in arch) or ('101' in arch) for arch in archs):
            cfgs, ops = generate_sm100_from_heuristics_configs(target_manifest, args.cuda_version, problem['configs'])
            problem_configs += cfgs
            problem_operations += ops

        operations += problem_operations
        problem_without_configs = {k: v for k, v in problem.items() if k != 'configs'}
        # One CSV/testlist row per (config, operation) pair, tagged with the
        # problem description it came from.
        with_problem_size = [{'operation_name': o.procedural_name(), **problem_without_configs, **c} for c, o in zip(problem_configs, problem_operations)]
        all_configs_and_operations += with_problem_size

    # Restrict the manifest to exactly the heuristics-selected kernels.
    for operation in operations:
        manifest.add_kernel_filter(f"^{operation.procedural_name()}$")
    if not all_configs_and_operations:
        raise Exception("No valid configurations generated")
    write_profiler_testlist_to_csv(all_configs_and_operations, args.heuristics_testlist_file)
    return all_configs_and_operations
|
||||
|
||||
def write_profiler_testlist_to_csv(configs_list, outfile_path):
    """
    Write a list of configs to a testlist to be consumed by cutlass_profiler.

    args:
        configs_list: List of kernel configs along with runtime arguments and any other columns to include in the CSV, expressed as a list of dictionaries
        outfile_path: Outfile path

    returns:
        None
    """
    def _csv_value(v):
        # Convert enum members to their short string names for CSV output.
        if isinstance(v, DataType):
            return DataTypeNames[v]
        if isinstance(v, LayoutType):
            return ShortLayoutTypeNames[v]
        return v

    # Build new row dicts rather than mutating the input: list.copy() is a
    # shallow copy, so in-place edits would leak back to the caller's dicts.
    profiler_testlist = [{k: _csv_value(v) for k, v in c.items()} for c in configs_list]

    if not profiler_testlist:
        # Nothing to write; avoid indexing an empty list for the header below.
        return

    with open(outfile_path, mode='w', newline='') as ofile:
        # Column names are taken from the first row's keys.
        writer = csv.DictWriter(ofile, fieldnames=profiler_testlist[0].keys())
        writer.writeheader()
        writer.writerows(profiler_testlist)
|
||||
168
python/cutlass_library/heuristics_provider.py
Normal file
168
python/cutlass_library/heuristics_provider.py
Normal file
@@ -0,0 +1,168 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Providers for kernel selection heuristics
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
import ctypes
|
||||
import functools
|
||||
|
||||
from library import DataType, LayoutType
|
||||
|
||||
class MatmulHeuristics:
    """Heuristics provider backed by the nvMatmulHeuristics (nvMMH) library.

    Wraps the nvMMH Python interface to suggest GEMM kernel configurations
    for the CUTLASS3 backend.
    """

    def __init__(self, gpu = None):
        # Imported here so the nvMatmulHeuristics package is only required
        # when a MatmulHeuristics instance is actually created.
        import nvMatmulHeuristics
        self.mmh_lib = nvMatmulHeuristics
        # GPU name key into NvMatmulHeuristicsNvidiaGpu, or None for auto-detect.
        self.gpu = gpu

        # Optionally load the nvMMH shared library from an explicit path.
        if 'CUTLASS_NVMMH_SO_PATH' in os.environ:
            nvmmhInterfaceEx = functools.partial(self.mmh_lib.NvMatmulHeuristicsInterfaceEx, path=os.environ['CUTLASS_NVMMH_SO_PATH'])
        else:
            nvmmhInterfaceEx = self.mmh_lib.NvMatmulHeuristicsInterfaceEx

        self.lh = nvmmhInterfaceEx(
            backend=self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"],
            flags=self.mmh_lib.NvMatmulHeuristicsFlags.PERF_MODEL_BASED_AUTO_TUNING,
            load_discovery_implicitly=True,
            gpu=self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None
        )
        self.backend = self.lh.createBackend(self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"])

    def _layout_from_cutlass(self, layouts):
        """Map a 3-tuple of CUTLASS LayoutType (A, B, D) to an nvMMH layout enum.

        Row-major maps to 't', column-major to 'n'; the first two characters
        select the input layout pair and the last one the output major order.
        """
        assert(len(layouts)==3)
        full_layout_str = ''.join('t' if l == LayoutType.RowMajor else 'n' for l in layouts)
        input_layouts = full_layout_str[:2].upper()
        lh_layout = input_layouts + '_' + str("ROW_MAJOR" if full_layout_str[-1]=='t' else "COL_MAJOR")
        return self.mmh_lib.NvMatmulHeuristicsMatmulLayout[lh_layout]

    def _precision_from_cutlass_dtypes(self, dtypes):
        """Build the nvMMH precision string from CUTLASS DataType values.

        dtypes is ordered (a, b, compute, c, d). Non-FP8 A types use the short
        3-letter form (a/compute/d); FP8 ('Q' = e4m3) uses the extended
        5-letter form that also encodes B and C.
        """
        # Single-letter codes follow the cuBLAS precision naming convention.
        dtype_to_cublas = {
            DataType.f64: 'D',
            DataType.f32: 'S',
            DataType.f16: 'H',
            DataType.bf16: 'T',
            DataType.e4m3: 'Q',
            DataType.e5m2: 'R',
            DataType.s32: 'I',
            DataType.s8: 'B',
        }

        dtype_a, dtype_b, dtype_compute, dtype_c, dtype_d = dtypes

        a_c = dtype_to_cublas[dtype_a]

        if a_c.lower() != 'q':
            return a_c + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d]
        else:
            return a_c + dtype_to_cublas[dtype_b] + dtype_to_cublas[dtype_c] + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d]

    def set_cta_div_n(self, div_n):
        """Require suggested CTA tile N to be divisible by div_n."""
        cta_n_div_requirement = ctypes.c_int(div_n)
        self.lh.setBackendValueProperty(
            self.backend,
            self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT,
            ctypes.byref(cta_n_div_requirement),
            ctypes.sizeof(cta_n_div_requirement)
        )

    def set_cta_div_m(self, div_m):
        """Require suggested CTA tile M to be divisible by div_m."""
        cta_m_div_requirement = ctypes.c_int(div_m)
        self.lh.setBackendValueProperty(
            self.backend,
            self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT,
            ctypes.byref(cta_m_div_requirement),
            ctypes.sizeof(cta_m_div_requirement)
        )

    def get_configs(self, m, n, k, batch_count, dtypes, layouts, align_a, align_b, voidC=False, use_fast_acc=True, count=1):
        """Query nvMMH for up to `count` kernel configs for one GEMM problem.

        Returns a list of dicts combining the nvMMH kernel suggestion
        (tile/cluster/swizzle/split-K/raster fields plus an estimated runtime)
        with the input dtypes, layouts, and alignments needed to build a
        CUTLASS GemmOperation.
        """
        # The fast-accumulation property is expressed as a "disable" flag on
        # the backend, so invert the caller's intent.
        if use_fast_acc:
            disable_fast_acc_for_fp8 = ctypes.c_int(0)
        else:
            disable_fast_acc_for_fp8 = ctypes.c_int(1)
        self.lh.setBackendValueProperty(
            self.backend,
            self.mmh_lib.NvMatmulHeuristicsBackendProperty.DISABLE_FAST_ACC_FOR_FP8,
            ctypes.byref(disable_fast_acc_for_fp8),
            ctypes.sizeof(disable_fast_acc_for_fp8)
        )

        precision = self._precision_from_cutlass_dtypes(dtypes)
        layout = self._layout_from_cutlass(layouts)

        matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, n, k, layout, batch_count)
        configs = self.lh.getEx(matmul_problem, count, self.backend, precision=precision)

        ret = []
        for c in configs:
            kernel = c['kernel']
            problem = c['problem']

            r = {}
            r['estimated_runtime'] = c['runtime']
            r['cta_tile_m'] = kernel.cta_tile_m
            r['cta_tile_n'] = kernel.cta_tile_n
            r['cta_tile_k'] = kernel.cta_tile_k
            r['instr_tile_m'] = kernel.instr_tile_m
            r['instr_tile_n'] = kernel.instr_tile_n
            r['instr_tile_k'] = kernel.instr_tile_k
            r['warp_tile_m'] = kernel.warp_tile_m
            r['warp_tile_n'] = kernel.warp_tile_n
            r['warp_tile_k'] = kernel.warp_tile_k
            r['cluster_m'] = kernel.cluster_m
            r['cluster_n'] = kernel.cluster_n
            # The kernel suggestion carries no cluster K; fixed to 1 here.
            r['cluster_k'] = 1
            r['layout_a'] = layouts[0]
            r['layout_b'] = layouts[1]
            r['layout_d'] = layouts[2]
            r['dtype_a'] = dtypes[0]
            r['dtype_b'] = dtypes[1]
            r['dtype_acc'] = dtypes[2]
            r['dtype_c'] = dtypes[3]
            r['dtype_d'] = dtypes[4]
            r['alignment_a'] = align_a
            r['alignment_b'] = align_b
            r['swizzle_size'] = kernel.swizzle_factor
            # cta_order 0 maps to rasterization along M, otherwise along N.
            r['raster_order'] = 'along_m' if kernel.cta_order==0 else 'along_n'
            r['split_k_slices'] = kernel.split_k
            r['use_fast_acc'] = use_fast_acc
            r['voidC'] = voidC

            ret.append(r)

        return ret
|
||||
|
||||
@@ -546,6 +546,22 @@ class KernelScheduleType(enum.Enum):
|
||||
Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
|
||||
Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()
|
||||
|
||||
# FP4 Ultra
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto()
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto()
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto()
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto()
|
||||
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto()
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto()
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto()
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto()
|
||||
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto()
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto()
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto()
|
||||
BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto()
|
||||
|
||||
Mxf8f6f4TmaWarpSpecializedCooperativeSm120 = enum_auto()
|
||||
Mxf8f6f4TmaWarpSpecializedPingpongSm120 = enum_auto()
|
||||
Nvf4TmaWarpSpecializedCooperativeSm120 = enum_auto()
|
||||
@@ -603,6 +619,22 @@ KernelScheduleTag = {
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100',
|
||||
|
||||
# FP4 Ultra
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103',
|
||||
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch',
|
||||
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch',
|
||||
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
|
||||
@@ -677,6 +709,21 @@ KernelScheduleSuffixes = {
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
|
||||
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
|
||||
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_1sm',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_2sm',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_1sm',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_2sm',
|
||||
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_1sm_nopf',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_2sm_nopf',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_1sm_nopf',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_2sm_nopf',
|
||||
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_1sm_tmapf',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_2sm_tmapf',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_1sm_tmapf',
|
||||
KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_2sm_tmapf',
|
||||
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
|
||||
KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
|
||||
@@ -713,8 +760,12 @@ class EpilogueScheduleType(enum.Enum):
|
||||
PtrArrayNoSmemWarpSpecialized = enum_auto()
|
||||
NoSmemWarpSpecialized1Sm = enum_auto()
|
||||
NoSmemWarpSpecialized2Sm = enum_auto()
|
||||
FastF32NoSmemWarpSpecialized1Sm = enum_auto()
|
||||
FastF32NoSmemWarpSpecialized2Sm = enum_auto()
|
||||
PtrArrayNoSmemWarpSpecialized1Sm = enum_auto()
|
||||
PtrArrayNoSmemWarpSpecialized2Sm = enum_auto()
|
||||
PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto()
|
||||
PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto()
|
||||
TmaWarpSpecialized = enum_auto()
|
||||
TmaWarpSpecializedCooperative = enum_auto()
|
||||
TmaWarpSpecialized1Sm = enum_auto()
|
||||
@@ -732,8 +783,12 @@ EpilogueScheduleTag = {
|
||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized',
|
||||
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::NoSmemWarpSpecialized1Sm',
|
||||
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm',
|
||||
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized1Sm',
|
||||
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm',
|
||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm',
|
||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm',
|
||||
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized1Sm',
|
||||
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized2Sm',
|
||||
EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
|
||||
EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm',
|
||||
@@ -752,8 +807,12 @@ EpilogueScheduleSuffixes = {
|
||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: '_epi_nosmem',
|
||||
EpilogueScheduleType.NoSmemWarpSpecialized1Sm: '_epi_nosmem',
|
||||
EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem',
|
||||
EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
|
||||
EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
|
||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem',
|
||||
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem',
|
||||
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
|
||||
EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
|
||||
EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
|
||||
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
|
||||
EpilogueScheduleType.TmaWarpSpecialized1Sm: '',
|
||||
|
||||
@@ -526,44 +526,49 @@ class Manifest:
|
||||
if args.filter_by_cc in ['false', 'False', '0']:
|
||||
self.filter_by_cc = False
|
||||
|
||||
if args.operations == 'all':
|
||||
self.operations_enabled = []
|
||||
else:
|
||||
operations_list = [
|
||||
OperationKind.Gemm
|
||||
, OperationKind.Conv2d
|
||||
, OperationKind.Conv3d
|
||||
, OperationKind.RankK
|
||||
, OperationKind.Trmm
|
||||
, OperationKind.Symm
|
||||
]
|
||||
self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')]
|
||||
if args.operations == 'all':
|
||||
self.operations_enabled = []
|
||||
else:
|
||||
operations_list = [
|
||||
OperationKind.Gemm
|
||||
, OperationKind.Conv2d
|
||||
, OperationKind.Conv3d
|
||||
, OperationKind.RankK
|
||||
, OperationKind.Trmm
|
||||
, OperationKind.Symm
|
||||
]
|
||||
self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')]
|
||||
|
||||
if args.kernels == 'all':
|
||||
self.kernel_names = []
|
||||
else:
|
||||
self.kernel_names = [x for x in args.kernels.split(',') if x != '']
|
||||
if args.kernels == 'all':
|
||||
self.kernel_names = []
|
||||
else:
|
||||
self.kernel_names = [x for x in args.kernels.split(',') if x != '']
|
||||
|
||||
self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != '']
|
||||
self.exclude_kernel_names = [x for x in args.exclude_kernels.split(',') if x != '']
|
||||
self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != '']
|
||||
self.exclude_kernel_names = [x for x in args.exclude_kernels.split(',') if x != '']
|
||||
|
||||
if args.kernel_filter_file is None:
|
||||
self.kernel_filter_list = []
|
||||
else:
|
||||
self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file)
|
||||
_LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format(
|
||||
filter_count = len(self.kernel_filter_list),
|
||||
filter_file = args.kernel_filter_file))
|
||||
if args.kernel_filter_file is None:
|
||||
self.kernel_filter_list = []
|
||||
else:
|
||||
self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file)
|
||||
_LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format(
|
||||
filter_count = len(self.kernel_filter_list),
|
||||
filter_file = args.kernel_filter_file))
|
||||
|
||||
self.operation_count = 0
|
||||
self.operations_by_name = {}
|
||||
self.disable_full_archs_compilation = args.disable_full_archs_compilation
|
||||
self.is_kernel_filter_set_to_all = args.instantiation_level == "max" and args.kernels != ''
|
||||
self.instantiation_level = 0
|
||||
try:
|
||||
self.instantiation_level = int(args.instantiation_level)
|
||||
except ValueError:
|
||||
self.instantiation_level = 0
|
||||
self.operation_count = 0
|
||||
self.operations_by_name = {}
|
||||
self.disable_full_archs_compilation = args.disable_full_archs_compilation
|
||||
self.is_kernel_filter_set_to_all = args.instantiation_level == "max" and args.kernels != ''
|
||||
self.instantiation_level = 0
|
||||
try:
|
||||
self.instantiation_level = int(args.instantiation_level)
|
||||
except ValueError:
|
||||
self.instantiation_level = 0
|
||||
|
||||
def add_kernel_filter(self, filter_str):
|
||||
filter_re = re.compile(filter_str)
|
||||
|
||||
self.kernel_filter_list.append(filter_re)
|
||||
|
||||
def get_sm90_instantiation_level(self, pruned_level=0, default_level=111, exhaustive_level=9992):
|
||||
# Non-negative integer which determines how many kernels are instantiated.
|
||||
|
||||
@@ -407,7 +407,7 @@ def generate_tile_descriptions_sm90(math_instructions, is_aligned: bool, level:
|
||||
|
||||
def is_tile_desc_compatible_with_cooperative(tile_description):
|
||||
# Cooperative kernels require a minimum CTA-M of 128
|
||||
return tile_description.threadblock_shape[0] >= 128
|
||||
return tile_description.threadblock_shape[0] % 128 == 0
|
||||
|
||||
|
||||
def can_tile_desc_use_shmem_in_epilogue(tile_description, data_types):
|
||||
|
||||
@@ -50,17 +50,17 @@ setup_pycute.perform_setup()
|
||||
|
||||
|
||||
setup(
|
||||
name='cutlass',
|
||||
version='3.4.0',
|
||||
name='cutlass_cppgen',
|
||||
version='4.0.0',
|
||||
description='CUTLASS Pythonic Interface',
|
||||
package_dir={'': '.'},
|
||||
packages=[
|
||||
'cutlass',
|
||||
'cutlass.emit',
|
||||
'cutlass.op',
|
||||
'cutlass.utils',
|
||||
'cutlass.backend',
|
||||
'cutlass.backend.utils'
|
||||
'cutlass_cppgen',
|
||||
'cutlass_cppgen.emit',
|
||||
'cutlass_cppgen.op',
|
||||
'cutlass_cppgen.utils',
|
||||
'cutlass_cppgen.backend',
|
||||
'cutlass_cppgen.backend.utils'
|
||||
],
|
||||
setup_requires=['pybind11'],
|
||||
install_requires=[
|
||||
|
||||
Reference in New Issue
Block a user