CUTLASS 3.2 (#1024)

* CUTLASS 3.2
This commit is contained in:
ANIKET SHIVAM
2023-08-07 14:50:32 -10:00
committed by GitHub
parent a0d787b746
commit 4575443d44
392 changed files with 47559 additions and 7940 deletions

View File

@@ -71,12 +71,15 @@ from library import (
DataType,
DataTypeSize,
EpilogueFunctor,
EpilogueScheduleSuffixes,
EpilogueScheduleTag,
EpilogueScheduleType,
GemmKind,
LayoutTag,
LayoutType,
KernelScheduleSuffixes,
KernelScheduleType,
KernelScheduleTag,
KernelScheduleType,
MathInstruction,
MathOperation,
OpcodeClass,
@@ -85,6 +88,9 @@ from library import (
SwizzlingFunctor,
TensorDescription,
TileDescription,
TileSchedulerSuffixes,
TileSchedulerTag,
TileSchedulerType
)
this = sys.modules[__name__]
@@ -106,11 +112,12 @@ from cutlass.backend.utils.device import device_cc
this.option_registry = OptionRegistry(device_cc())
this.__version__ = '3.1.0'
this.__version__ = '3.2.0'
from cutlass.backend import get_memory_pool
from cutlass.emit.pytorch import pytorch
from cutlass.op.gemm import Gemm
from cutlass.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
from cutlass.op.gemm_grouped import GroupedGemm
from cutlass.op.op import OperationBase

View File

@@ -118,6 +118,7 @@ class GenericMainloopArguments3x_(ctypes.Structure):
("stride_A", StrideBatched_),
("ptr_B", ctypes.c_void_p),
("stride_B", StrideBatched_),
("mma_promotion_interval", ctypes.c_int)
]
@@ -148,12 +149,14 @@ def get_mainloop_arguments_3x(
("stride_A", StrideBatched_),
("ptr_B", ctypes.c_void_p),
("stride_B", StrideBatched_),
("mma_promotion_interval", ctypes.c_int)
]
@staticmethod
def from_generic_mainloop_args(args: GenericMainloopArguments3x_):
return _MainloopArgumentsTma(
args.ptr_A, args.stride_A, args.ptr_B, args.stride_B,
args.mma_promotion_interval
)
class _MainloopArgumentsMultistage(ctypes.Structure):
@@ -203,15 +206,23 @@ def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor):
("stride_D", StrideBatched_),
]
class _HardwareInfo(ctypes.Structure):
_fields_ = [
("device_id", ctypes.c_int),
("sm_count", ctypes.c_int)
]
class _GemmArguments(ctypes.Structure):
_fields_ = [
("mode", ctypes.c_int),
("problem_size", GemmCoordBatched_),
("mainloop", mainloop_arguments),
("epilogue", _EpilogueArguments)
("epilogue", _EpilogueArguments),
("hw_info", _HardwareInfo),
("splits", ctypes.c_int)
]
return _GemmArguments, _EpilogueArguments, _EpilogueOutputOpParams
return _GemmArguments, _EpilogueArguments, _EpilogueOutputOpParams, _HardwareInfo
def get_gemm_arguments(epilogue_functor):

View File

@@ -39,16 +39,37 @@ import tempfile
from cuda import cuda, nvrtc
import cutlass_bindings
from cutlass import CACHE_FILE, CUDA_INSTALL_PATH, CUTLASS_PATH
from cutlass import CACHE_FILE, CUDA_INSTALL_PATH, CUTLASS_PATH, logger
from cutlass.backend.gemm_operation import GemmOperationUniversal
from cutlass.backend.library import ApiVersion
from cutlass.backend.utils.device import device_cc
from cutlass.backend.utils.software import SubstituteTemplate
import subprocess
IncludeTemplate = r"""#include "${include}"
"""
def compile_with_nvcc(cmd, source, error_file):
    """Run an nvcc compilation command, persisting any failure log.

    On success this returns None. On failure, the kernel source and nvcc's
    error output are written to ``error_file``, emitted through the logger
    (visible at warning verbosity or higher), and an Exception is raised
    pointing the user at the log file.

    :param cmd: full shell command line to execute
    :type cmd: str
    :param source: kernel source code being compiled (included in the log)
    :type source: str
    :param error_file: path of the file to which the error log is written
    :type error_file: str
    """
    error_log = None
    try:
        subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True)
    except subprocess.CalledProcessError as e:
        error_log = (
            "Compilation error for the following kernel: \n"
            + source
            + "\nError Message:\n"
            + e.output.decode()
        )
        with open(error_file, "w") as error_out:
            error_out.write(error_log)
    if error_log is not None:
        # Print the error log to stdout if log level is set to warning or higher
        # verbosity. Otherwise, simply point to the error log file.
        logger.warning(error_log)
        raise Exception(f"Invalid Kernel. See '{error_file}' for details.")
class CompilationOptions:
"""
Compilation options.
@@ -129,20 +150,24 @@ class ArtifactManager:
connection.commit()
cursor.close()
self._nvrtc_compile_options = ["-std=c++17", "-default-device"]
self._nvcc_compile_options = [
"-std=c++17",
"--expt-relaxed-constexpr",
"-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored",
]
self.nvcc()
self.compiled_cache_device = cutlass_bindings.CompileCache()
self.compiled_cache_host = cutlass_bindings.CompileCache()
def nvrtc(self):
self.backend = "nvrtc"
self.default_compile_options = ["-std=c++17", "-default-device"]
self.default_compile_options = self._nvrtc_compile_options
def nvcc(self):
self.backend = "nvcc"
self.default_compile_options = [
"-std=c++17",
"--expt-relaxed-constexpr",
"-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored",
]
self.default_compile_options = self._nvcc_compile_options
def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs):
connection = sqlite3.connect(CACHE_FILE)
cursor = connection.cursor()
@@ -200,7 +225,7 @@ class ArtifactManager:
self.compiled_cache_host.insert(key, compiled_host_fns)
return True
def emit_compile_(self, operation_list, compilation_options):
def emit_compile_(self, operation_list, compilation_options, host_compilation_options):
"""
Compile a list of kernels and store them into database
"""
@@ -299,7 +324,7 @@ class ArtifactManager:
"tarfile": temp_cubin.name,
}
cmd = SubstituteTemplate(cmd_template, values)
os.system(cmd)
compile_with_nvcc(cmd, source_buffer_device, "./cutlass_python_compilation_device_error.txt")
# load the cubin image
with open(temp_cubin.name, "rb") as file:
@@ -314,7 +339,7 @@ class ArtifactManager:
cmd_template,
{
"cuda_install_path": CUDA_INSTALL_PATH,
"options": compilation_options.get_str(),
"options": host_compilation_options.get_str(),
},
)
@@ -323,29 +348,31 @@ class ArtifactManager:
prefix="host_func", suffix=".so", delete=True)
cmd += " - -shared -o %s -lcudart -lcuda" % temp.name
os.system(cmd)
compile_with_nvcc(cmd, source_buffer_host, error_file="./cutlass_python_compilation_host_error.txt")
host_lib = ctypes.CDLL(temp.name)
return cubin_image, host_lib, temp
def add_module(self, operations, compile_options=None):
def add_module(self, operations, compile_options=None, bypass_cache=False):
"""
Insert a new compiled device module
"""
if compile_options is None:
include_paths = [
CUDA_INSTALL_PATH + "/include",
CUTLASS_PATH + "/include",
CUTLASS_PATH + "/tools/util/include",
CUTLASS_PATH + "/python/cutlass/cpp/include",
]
include_paths = [
CUDA_INSTALL_PATH + "/include",
CUTLASS_PATH + "/include",
CUTLASS_PATH + "/tools/util/include",
CUTLASS_PATH + "/python/cutlass/cpp/include",
]
if device_cc() is not None:
arch = device_cc()
else:
# Find the maximum arch tag among the provided operations and compile for that target.
# Since we are compiling to .cubin files, only one architecture may be specified.
arch = max([op.arch for op in operations])
if device_cc() is not None:
arch = device_cc()
else:
# Find the maximum arch tag among the provided operations and compile for that target.
# Since we are compiling to .cubin files, only one architecture may be specified.
arch = max([op.arch for op in operations])
host_compile_options = CompilationOptions(
self._nvcc_compile_options, arch, include_paths)
if compile_options is None:
compile_options = CompilationOptions(
self.default_compile_options, arch, include_paths)
# save the cubin
@@ -357,7 +384,7 @@ class ArtifactManager:
# step 1: check if the operation is in cache
compiled_kernel = self.compiled_cache_device.at(key)
if compiled_kernel is None:
if compiled_kernel is None and not bypass_cache:
hit = self.load_operation(key, getattr( operation.rt_module, "extra_funcs", {}))
if hit:
compiled_kernel = self.compiled_cache_device.at(key)
@@ -375,7 +402,7 @@ class ArtifactManager:
if len(operation_list) > 0:
cubin_image, host_lib, host_file = self.emit_compile_(
operation_list, compile_options)
operation_list, compile_options, host_compile_options)
err, module = cuda.cuModuleLoadData(cubin_image)
if err != cuda.CUresult.CUDA_SUCCESS:

View File

@@ -41,6 +41,7 @@ import numpy as np
from cutlass.backend.arguments import ArgumentBase
from cutlass.backend.c_types import Conv2DProblemSize, TensorRef_, get_conv2d_arguments
from cutlass.backend.library import (
EmissionType,
ConvKindNames,
ConvKindTag,
DataTypeNames,
@@ -123,17 +124,17 @@ class Conv2dArguments(ArgumentBase):
super().__init__(A, B, C, D, **kwargs)
# preprocessing output ops
if "output_op" in kwargs.keys() and split_k_mode != cutlass_bindings.conv.SplitKMode.Parallel:
self.output_op = kwargs["output_op"]
else:
self.output_op = self.operation.epilogue_type(1.0, 0.0)
if "split_k_slices" in kwargs.keys():
if "split_k_slices" in kwargs.keys() and kwargs["split_k_slices"] > 1:
self.split_k_mode = split_k_mode
self.split_k_slices = kwargs["split_k_slices"]
else:
self.split_k_mode = cutlass_bindings.conv.SplitKMode.Serial
self.split_k_slices = 1
if "output_op" in kwargs.keys() and self.split_k_mode != cutlass_bindings.conv.SplitKMode.Parallel:
self.output_op = kwargs["output_op"]
else:
self.output_op = self.operation.epilogue_type(1.0, 0.0)
#: problem_size
self.problem_size: cutlass_bindings.conv.Conv2dProblemSize = problem_size
@@ -419,7 +420,9 @@ class Conv2dOperation:
C: TensorDescription,
stride_support,
epilogue_functor,
swizzling_functor=cutlass_bindings.IdentitySwizzle1
swizzling_functor=cutlass_bindings.IdentitySwizzle1,
emission_type=EmissionType.Kernel,
**kwargs
):
self.operation_kind: OperationKind = OperationKind.Conv2d
self.arch: int = arch
@@ -432,6 +435,8 @@ class Conv2dOperation:
self.iterator_algorithm = iterator_algorithm
self.stride_support = stride_support
self.swizzling_functor = swizzling_functor()
self.emission_type = emission_type
self.rt_module: Conv2dRT = Conv2dRT(self)
self.argument_type = self.rt_module.argument_type
@@ -562,6 +567,18 @@ class Conv2dOperation:
return accum
def device_op(self):
    """
    Returns a new Conv2dOperation object that is constructed with emission type
    ``EmissionType.Device``.

    :return: operation ready for device-level code emission
    :rtype: Conv2dOperation
    """
    # The swizzling functor is stored as an instance, while the constructor
    # expects the functor class — hand its type back.
    swizzle_cls = type(self.swizzling_functor)
    return Conv2dOperation(
        self.conv_kind,
        self.iterator_algorithm,
        self.arch,
        self.tile_description,
        self.A,
        self.B,
        self.C,
        self.stride_support,
        self.epilogue_functor,
        swizzle_cls,
        emission_type=EmissionType.Device)
###################################################################################################
#
@@ -596,7 +613,7 @@ typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor},
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
${swizzling_functor},
${stages},
${math_operator},
${iterator_algorithm},
@@ -608,6 +625,36 @@ typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
struct ${operation_name}${operation_suffix}:
public ${operation_name}_base { };
"""
self.template_device = """
// Conv2d operation ${operation_name}
using Conv2d${conv_kind_name}Kernel = typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
${element_a},
${layout_a},
${element_b},
${layout_b},
${element_c},
${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor},
${swizzling_functor},
${stages},
${math_operator},
${iterator_algorithm},
${stride_support},
${align_a},
${align_b}
>::Kernel;
using DeviceKernel =
typename cutlass::conv::device::ImplicitGemmConvolution<Conv2d${conv_kind_name}Kernel>;
"""
def emit(self, operation):
@@ -651,5 +698,10 @@ struct ${operation_name}${operation_suffix}:
"align_a": str(operation.A.alignment),
"align_b": str(operation.B.alignment),
}
if operation.emission_type == EmissionType.Kernel:
conv2d_template = self.template
else:
conv2d_template = self.template_device
return SubstituteTemplate(self.template, values)
return SubstituteTemplate(conv2d_template, values)

View File

@@ -39,7 +39,17 @@ import cutlass_bindings
import numpy as np
import rmm
from cutlass import KernelScheduleSuffixes, KernelScheduleTag, KernelScheduleType
from cutlass import (
EpilogueScheduleSuffixes,
EpilogueScheduleTag,
EpilogueScheduleType,
KernelScheduleSuffixes,
KernelScheduleTag,
KernelScheduleType,
TileSchedulerSuffixes,
TileSchedulerTag,
TileSchedulerType
)
from cutlass.backend.arguments import ArgumentBase
from cutlass.backend.c_types import (
GemmCoord_,
@@ -55,6 +65,7 @@ from cutlass.backend.c_types import (
)
from cutlass.backend.library import (
ApiVersion,
EmissionType,
ComplexTransformTag,
DataTypeNames,
DataTypeSize,
@@ -548,6 +559,7 @@ class GemmArguments3x(GemmArguments2x):
stride_A,
int(self.ptr_B),
stride_B,
4 # mma_promotion_interval
)
# Set of mainloop arguments needed for this kernel
@@ -561,11 +573,15 @@ class GemmArguments3x(GemmArguments2x):
stride_D,
)
# Set hardware info
hw_info = self.operation.rt_module.hw_info(0, device_sm_count())
self.arguments = self.operation.argument_type(
self.gemm_mode,
problem_size_,
mainloop,
epilogue,
hw_info,
)
return self.arguments
@@ -1102,6 +1118,11 @@ extern "C" {
using GemmType = ${operation_name}_base;
// Get the workspace size
uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* argument) {
return GemmType::get_workspace_size(*argument);
}
// Get the params as byte array
char* ${operation_name}_get_params(GemmType::Arguments* argument, int* workspace){
GemmType::Params params = GemmType::to_underlying_arguments(*argument, workspace);
@@ -1118,7 +1139,7 @@ extern "C" {
uint64_t ${operation_name}_get_persistent_tiled_blk_shape_mnl(GemmType::ProblemShape problem) {
auto problem_shape_MNKL = append<4>(problem, Int<1>{});
auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] =
cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl(
cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::get_tiled_cta_shape_mnl(
problem_shape_MNKL, GemmType::TileShape{}, GemmType::DispatchPolicy::ClusterShape{});
return problem_blocks_m * problem_blocks_n * problem_blocks_l;
}
@@ -1141,7 +1162,8 @@ extern "C" {
self.extra_funcs = {
"get_grid_shape": dim3_,
"get_block_shape": dim3_,
"get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64
"get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64,
"get_kernel_workspace_size": ctypes.c_uint64
}
self.emitter = EmitGemmUniversalInstance3x("_type")
self.mainloop_args = get_mainloop_arguments_3x(
@@ -1151,7 +1173,10 @@ extern "C" {
operation.A.alignment,
operation.B.alignment
)
self.argument_type, self.epilogue_args, self.epilogue_type = get_gemm_arguments_3x(self.mainloop_args, operation.epilogue_functor)
self.argument_type, self.epilogue_args, self.epilogue_type, self.hw_info = get_gemm_arguments_3x(self.mainloop_args, operation.epilogue_functor)
def get_device_workspace_size(self, arguments: GemmArguments3x):
return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments()))
class EmitGemmUniversalInstance3x:
@@ -1183,7 +1208,7 @@ using CollectiveEpilogue =
${element_accumulator}, ${element_epilogue},
${element_c}, ${layout_c}, ${align_c},
${element_d}, ${layout_d}, ${align_d},
cutlass::epilogue::collective::EpilogueScheduleAuto
${epilogue_schedule}
>::CollectiveOp;
using CollectiveMainloop =
@@ -1202,7 +1227,8 @@ using CollectiveMainloop =
using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue
CollectiveEpilogue,
${tile_scheduler}
>;
// Define named type
@@ -1233,9 +1259,15 @@ using DeviceKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_nam
else:
gemm_template = self.gemm_template_device
schedule = KernelScheduleType.ScheduleAuto
kschedule = KernelScheduleType.ScheduleAuto
eschedule = EpilogueScheduleType.ScheduleAuto
tschedule = TileSchedulerType.Default
if operation.tile_description.kernel_schedule is not None:
schedule = operation.tile_description.kernel_schedule
kschedule = operation.tile_description.kernel_schedule
if operation.tile_description.epilogue_schedule is not None:
eschedule = operation.tile_description.epilogue_schedule
if operation.tile_description.tile_scheduler is not None:
tschedule = operation.tile_description.tile_scheduler
values = {
"operation_name": operation.procedural_name(),
@@ -1264,7 +1296,9 @@ using DeviceKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_nam
"align_c": str(operation.C.alignment),
"align_d": str(operation.C.alignment),
"stage_count_type": stage_count_type,
"kernel_schedule": KernelScheduleTag[schedule],
"kernel_schedule": KernelScheduleTag[kschedule],
"epilogue_schedule": EpilogueScheduleTag[eschedule],
"tile_scheduler": TileSchedulerTag[tschedule]
}
values["epilogue_functor"] = operation.epilogue_functor.emit()
@@ -1382,15 +1416,6 @@ ${operation_name}(${operation_name}${operation_suffix}::Params params) {
################################################################################
class EmissionType(enum.Enum):
"""
Tags for whether to emit a kernel- or device-level operation
"""
Kernel = enum_auto()
Device = enum_auto()
class GemmOperationBase:
"""
CUTLASS GEMM operation
@@ -1595,11 +1620,18 @@ class GemmOperationBase:
else:
return KernelScheduleSuffixes[self.tile_description.kernel_schedule]
# Generates a short string representing underlying epilogue schedule type
def epilogue_schedule_name_3x(self):
if self.tile_description.epilogue_schedule is None:
return EpilogueScheduleSuffixes[EpilogueScheduleType.ScheduleAuto]
else:
return EpilogueScheduleSuffixes[self.tile_description.epilogue_schedule]
def procedural_name(self):
"""The full procedural name indicates architecture, extended name, tile size, and layout."""
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
if self.api == ApiVersion.v3x and self.arch >= 90:
kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}"
kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}"
return kernel_name_template.format(
p=self.prefix,
ar=self.arch,
@@ -1614,7 +1646,8 @@ class GemmOperationBase:
l=self.tile_description.stages,
s=self.layout_name_3x(),
al=str(self.A.alignment),
k=self.kernel_schedule_name_3x()
k=self.kernel_schedule_name_3x(),
e=self.epilogue_schedule_name_3x()
)
else:
threadblock = self.tile_description.procedural_name()

View File

@@ -38,7 +38,7 @@ but uses the Pybind-bound CUTLASS data types as many keys to the dictionary.
import enum
import cutlass_bindings
from cutlass import KernelScheduleType
from cutlass import EpilogueScheduleType, KernelScheduleType, TileSchedulerType
# The following block implements enum.auto() for Python 3.5 variants that don't include it such
@@ -554,7 +554,9 @@ class TileDescription:
warp_count,
math_instruction,
cluster_shape=[1, 1, 1],
kernel_schedule: KernelScheduleType = None
kernel_schedule: KernelScheduleType = None,
epilogue_schedule: EpilogueScheduleType = None,
tile_scheduler: TileSchedulerType = None,
):
"""
:param threadblock_shape: shape of a threadblock tile
@@ -568,18 +570,61 @@ class TileDescription:
:type math_instruction: MathInstruction
:param cluster_shape: number of threadblocks in the [X, Y, Z] dimensions of a threadblock cluster
:param kernel_schedule: type of kernel schedule to use (only available for SM90+)
:type kernel_schedule: cutlass.backend.KernelScheduleType
:type kernel_schedule: cutlass.KernelScheduleType
:param epilogue_schedule: type of epilogue schedule to use (only available for SM90+)
:type epilogue_schedule: cutlass.EpilogueScheduleType
:param tile_scheduler: type of tile scheduler to use (only available for SM90+)
:type tile_scheduler: cutlass.TileSchedulerType
"""
if ((kernel_schedule is None and epilogue_schedule is not None) or
(kernel_schedule is not None and epilogue_schedule is None)):
raise Exception("Kernel and epilogue schedule must either both be Auto or neither be Auto.")
self.threadblock_shape = threadblock_shape
self.cluster_shape = cluster_shape
self.kernel_schedule = kernel_schedule
self.stages: int = stages
self.epilogue_schedule = epilogue_schedule
self.tile_scheduler = tile_scheduler
self.stages = stages
self.math_instruction = math_instruction
self.instruction_shape = math_instruction.instruction_shape
# Number of warps along x, y, z directions
self.warp_count = warp_count
def clone_and_update(self, td: dict):
    """Return a copy of this TileDescription with selected fields overridden.

    :param td: mapping from attribute name (one of "cluster_shape",
        "threadblock_shape", "warp_count", "stages", "instruction_shape",
        "kernel_schedule", "epilogue_schedule", "tile_scheduler") to its new
        value; attributes absent from ``td`` are copied from ``self``
    :type td: dict
    :return: new TileDescription reflecting the requested overrides
    :rtype: TileDescription
    """
    attrs = {
        "cluster_shape": None,
        "threadblock_shape": None,
        "warp_count": None,
        "stages": None,
        "instruction_shape": None,
        "kernel_schedule": None,
        "epilogue_schedule": None,
        "tile_scheduler": None
    }
    for key in attrs.keys():
        if key in td.keys():
            attrs[key] = td[key]
        else:
            attrs[key] = getattr(self, key)

    # Operand/accumulator types and the math operation are carried over
    # unchanged; only the instruction shape may be overridden.
    mi = MathInstruction(
        attrs["instruction_shape"],
        self.math_instruction.element_a,
        self.math_instruction.element_b,
        self.math_instruction.element_accumulator,
        self.math_instruction.opcode_class,
        self.math_instruction.math_operation
    )

    return TileDescription(
        attrs["threadblock_shape"], attrs["stages"],
        attrs["warp_count"], mi, attrs["cluster_shape"],
        attrs["kernel_schedule"], attrs["epilogue_schedule"],
        # Bug fix: the tile_scheduler override was previously collected but
        # silently dropped from the cloned description.
        attrs["tile_scheduler"]
    )
@property
def num_threads(self):
"""
@@ -622,16 +667,30 @@ class TileDescription:
:return: contents of tile description
:rtype: str
"""
schedule = KernelScheduleType.ScheduleAuto
if self.kernel_schedule is not None:
schedule = self.kernel_schedule
kschedule = self.kernel_schedule
else:
kschedule = KernelScheduleType.ScheduleAuto
if self.epilogue_schedule is not None:
eschedule = self.epilogue_schedule
else:
eschedule = EpilogueScheduleType.ScheduleAuto
if self.tile_scheduler is not None:
tschedule = self.tile_scheduler.name
else:
tschedule = "None"
return f"""
{{
ClusterShape: {self.cluster_shape}
ThreadblockShape: {self.threadblock_shape}
WarpCount: {self.warp_count}
Stages: {self.stages if self.stages is not None else 'Auto'}
Kernel schedule: {schedule.name}
InstructionShape: {self.math_instruction.instruction_shape}
Kernel schedule: {kschedule.name}
Epilogue schedule: {kschedule.name}
TileScheduler: {tschedule}
}}"""
@@ -712,3 +771,12 @@ def api_version(arch, opclass, datatype):
return ApiVersion.v3x
else:
return ApiVersion.v2x
class EmissionType(enum.Enum):
"""
Tags for whether to emit a kernel- or device-level operation
"""
Kernel = enum_auto()
Device = enum_auto()

View File

@@ -38,7 +38,7 @@ from bfloat16 import bfloat16
import cutlass_bindings
import numpy as np
from cutlass.backend.compiler import ArtifactManager
from cutlass.backend import compiler
from cutlass.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
from cutlass.backend.library import DataTypeSize, ShortDataTypeNames, StrideSupport
from cutlass.backend.memory_manager import get_allocated_size
@@ -127,7 +127,6 @@ def getTensorView(tensor, tensor_layout, conv_kind, problem_size, operand):
raise ValueError("unsupported data type")
# @typechecked
class Conv2dLauncher:
"""
Launcher that runs the operation on given problem size
@@ -142,6 +141,7 @@ class Conv2dLauncher:
profiling=False,
warmup_iterations=500,
iterations=500,
compilation_mode="nvcc",
**kwargs,
) -> None:
self.enable_cached_results = True
@@ -176,7 +176,14 @@ class Conv2dLauncher:
# Compile the operator
#
ArtifactManager().add_module([operation, self.reduction_operation])
if compilation_mode == "nvcc":
compiler.nvcc()
elif compilation_mode == "nvrtc":
compiler.nvrtc()
else:
raise Exception(f"Unexpected compilation mode {compilation_mode}")
compiler.add_module([operation, self.reduction_operation])
self.operation = operation
@@ -195,14 +202,14 @@ class Conv2dLauncher:
element_size = DataTypeSize[operation.A.element]
if element_size <= 8:
self.scope = 1
self.randomization_max = 1
elif element_size == 16:
if accumulator_size <= 16:
self.scope = 2
self.randomization_max = 2
else:
self.scope = 4
self.randomization_max = 4
else:
self.scope = 7
self.randomization_max = 7
# Seed
self.seed = seed
@@ -263,12 +270,12 @@ class Conv2dLauncher:
if dtype in [np.float32, np.float16, bfloat16, np.float64]:
return np.ceil(
np.random.uniform(
low=-self.scope - 0.5, high=self.scope - 0.5, size=size
low=-self.randomization_max - 0.5, high=self.randomization_max - 0.5, size=size
).astype(dtype)
)
else:
return np.random.uniform(
low=-self.scope - 1, high=self.scope + 1, size=size
low=-self.randomization_max - 1, high=self.randomization_max + 1, size=size
).astype(dtype)
def eq_gemm_size(self, problem_size):
@@ -624,13 +631,15 @@ class Conv2dLauncher:
############################################################################################################
def test_all_conv2d(operation: Conv2dOperation, conv_test_sizes=[], interleaved=False):
passed = True
#
# Testbed object
#
def test_all_conv2d_from_compilation_mode(
operation: Conv2dOperation,
conv_test_sizes,
interleaved,
compilation_mode):
testbed = Conv2dLauncher(operation, interleaved=interleaved)
passed = True
testbed = Conv2dLauncher(operation, interleaved=interleaved, compilation_mode=compilation_mode)
#
# Get conv problem sizes to run conv operator
@@ -781,3 +790,18 @@ def test_all_conv2d(operation: Conv2dOperation, conv_test_sizes=[], interleaved=
)
return passed
def test_all_conv2d(
        operation: Conv2dOperation,
        conv_test_sizes=None,
        interleaved=False,
        compilation_modes=None):
    """Run the full conv2d test suite once per requested compilation mode.

    :param operation: conv2d operation to test
    :type operation: Conv2dOperation
    :param conv_test_sizes: additional problem sizes to exercise (default: none)
    :param interleaved: whether the operation uses an interleaved layout
    :type interleaved: bool
    :param compilation_modes: compilers to test with (default: nvcc and nvrtc)
    :return: True if every compilation mode passes, False at the first failure
    :rtype: bool
    """
    # Avoid mutable default arguments: materialize per-call defaults here.
    if conv_test_sizes is None:
        conv_test_sizes = []
    if compilation_modes is None:
        compilation_modes = ["nvcc", "nvrtc"]

    for compilation_mode in compilation_modes:
        passed = test_all_conv2d_from_compilation_mode(
            operation, conv_test_sizes, interleaved, compilation_mode)
        if not passed:
            return False
    return True

View File

@@ -177,6 +177,7 @@ class GemmUniversalLauncher:
profiling=False,
warmup_iterations=500,
iterations=500,
compiler_mode: str = "nvcc",
**kwargs,
) -> None:
# create the reduction kernel
@@ -209,13 +210,19 @@ class GemmUniversalLauncher:
#
# Compile the operator
#
if compiler_mode == "nvcc":
compiler.nvcc()
elif compiler_mode == "nvrtc":
compiler.nvrtc()
else:
raise Exception(f"Unexpected compiler string {compiler_mode}")
op_list = [operation]
if operation.arch < 90:
# Split K via Python is currently only supported for pre-SM90 kernels
op_list.append(self.reduction_operation)
compiler.add_module(op_list)
compiler.add_module(op_list, bypass_cache=True)
self.operation = operation
@@ -603,7 +610,7 @@ class GemmUniversalLauncher:
return passed
def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal"):
def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal", compilation_mode="nvcc"):
passed = True
minimum_operand_element_size = min(
@@ -711,7 +718,7 @@ def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal"):
problem_alpha = [1.0]
problem_beta = [2.0]
testbed = GemmUniversalLauncher(operation, interleaved=(testcase == "interleaved"))
testbed = GemmUniversalLauncher(operation, interleaved=(testcase == "interleaved"), compiler_mode=compilation_mode)
for mode in modes:
for m in problem_size_m:

View File

@@ -30,10 +30,13 @@
#
#################################################################################################
import cutlass
import cutlass_bindings
from cutlass import KernelScheduleSuffixes
from cutlass import EpilogueScheduleSuffixes, KernelScheduleSuffixes
from cutlass.utils.datatypes import binding_opclass, binding_type
from cutlass.backend import library
from cutlass.backend.test.gemm_testbed import test_all_gemm
from cutlass.backend.utils.software import SubstituteTemplate
@@ -75,6 +78,7 @@ def get_name(
arch,
opclass,
kernel_schedule=None,
epilogue_schedule=None,
suffix="",
):
"""
@@ -97,24 +101,26 @@ def get_name(
:type opclass: cutlass_bindings.OpClass
:param kernel_schedule: kernel_schedule type
:type kernel_schedule: cutlass.KernelScheduleType
:param epilogue_schedule: epilogue_schedule type
:type epilogue_schedule: cutlass.EpilogueScheduleType
:param suffix: additional string to add to the suffix of the name
:type suffix: str
:return: str
"""
name_format = "test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${suffix}"
name_format = "test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${e}${suffix}"
return SubstituteTemplate(
name_format,
{
"arch": str(arch),
"eA": library.DataTypeNames[element_a],
"eB": library.DataTypeNames[element_b],
"eC": library.DataTypeNames[element_output],
"eA": library.DataTypeNames[binding_type(element_a)],
"eB": library.DataTypeNames[binding_type(element_b)],
"eC": library.DataTypeNames[binding_type(element_output)],
"lA": library.ShortLayoutTypeNames[layouts[0]],
"lB": library.ShortLayoutTypeNames[layouts[1]],
"lC": library.ShortLayoutTypeNames[layouts[2]],
"opclass": library.OpcodeClassNames[opclass],
"acc": library.DataTypeNames[element_accumulator],
"opclass": library.OpcodeClassNames[binding_opclass(opclass)],
"acc": library.DataTypeNames[binding_type(element_accumulator)],
"cM": str(cluster_shape[0]),
"cN": str(cluster_shape[1]),
"cK": str(cluster_shape[2]),
@@ -126,6 +132,174 @@ def get_name(
"aB": str(alignments[1]),
"aC": str(alignments[2]),
"k": "" if kernel_schedule is None else KernelScheduleSuffixes[kernel_schedule],
"e": "" if epilogue_schedule is None else EpilogueScheduleSuffixes[epilogue_schedule],
"suffix": "" if suffix is None else suffix,
},
)
def get_name_conv2d(
        arch,
        conv_kind,
        element,
        element_accumulator,
        element_output,
        opclass,
        threadblock_shape,
        warp_count,
        instruction_shape,
        stages,
        iterator_algorithm,
        swizzle,
        split_k_mode,
        split_k_slices,
        activation
):
    """
    Generates a procedural name for a conv2d test case.

    :param arch: compute capability of kernel being generated
    :type arch: int
    :param conv_kind: the convolution type (i.e. fprop, dgrad, wgrad)
    :type conv_kind: str
    :param element: data type of operands A and B
    :param element_accumulator: data type used in accumulation
    :param element_output: data type of the output element
    :param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
    :type opclass: str
    :param threadblock_shape: indexable container of dimensions of threadblock tiles
    :param warp_count: indexable container of warps per threadblock dimension
    :param instruction_shape: indexable container of instruction-shape dimensions
    :param stages: number of pipeline stages to use in the kernel
    :type stages: int
    :param iterator_algorithm: the iterator algorithm applied (None selects "AUTO")
    :param swizzle: threadblock swizzling amount (None selects 1)
    :param split_k_mode: how split-K is performed
    :type split_k_mode: str
    :param split_k_slices: number of split-K slices
    :type split_k_slices: int
    :param activation: name of the activation applied
    :type activation: str
    :return: str
    """
    if iterator_algorithm is None:
        iterator_algorithm = "AUTO"
    if swizzle is None:
        swizzle = 1

    name_format = "test_SM${arch}_Device_Conv2d_${conv_kind}_${iter_alg}_ImplicitGemm_${eA}nhwc_${eB}nhwc_${eC}nhwc_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${wM}x${wN}x${wK}_${IM}${IN}${IK}_stage${stages}_swizzle${swizzle}_${split_k_mode}${split_k_slices}_${activation}"

    return SubstituteTemplate(
        name_format,
        {
            "arch": str(arch),
            "conv_kind": conv_kind,
            "iter_alg": iterator_algorithm,
            "eA": library.DataTypeNames[binding_type(element)],
            "eB": library.DataTypeNames[binding_type(element)],
            "eC": library.DataTypeNames[binding_type(element_output)],
            "opclass": opclass,
            "acc": library.DataTypeNames[binding_type(element_accumulator)],
            "tbM": str(threadblock_shape[0]),
            "tbN": str(threadblock_shape[1]),
            "tbK": str(threadblock_shape[2]),
            # Warp tile = threadblock tile divided by the warp count per dim.
            "wM": str(threadblock_shape[0] // warp_count[0]),
            "wN": str(threadblock_shape[1] // warp_count[1]),
            "wK": str(threadblock_shape[2] // warp_count[2]),
            "IM": str(instruction_shape[0]),
            "IN": str(instruction_shape[1]),
            "IK": str(instruction_shape[2]),
            "stages": str(stages),
            "swizzle": str(swizzle),
            "split_k_mode": split_k_mode,
            "split_k_slices": str(split_k_slices),
            "activation": activation
        }
    )
def add_test_gemm(
        cls=None,
        cc=None,
        element=None,
        layouts=None,
        alignments=None,
        element_output=None,
        element_accumulator=None,
        cluster_shape=None,
        threadblock_shape=None,
        warp_count=None,
        stages=None,
        opclass=None,
        swizzle=None,
        kernel_schedule=None,
        epilogue_schedule=None,
        compilation_modes=['nvcc', 'nvrtc']):
    """
    Create test-running functions with the given specification and set it as a method of ``cls``.

    :param cls: class to which the generated method will be added
    :type cls: type
    :param cc: compute capability to compile for
    :type cc: int
    :param element: data type of A and B operands
    :type element: cutlass.DataType.f16
    :param layouts: layouts of A, B, and C operands
    :type layouts: list or tuple
    :param alignments: alignments of A, B, and C operands
    :type alignments: list or tuple
    :param element_output: data type of the output element
    :type element_output: cutlass.DataType
    :param element_accumulator: data type used in accumulation
    :type element_accumulator: cutlass.DataType
    :param cluster_shape: dimensions of clusters
    :type cluster_shape: list or tuple
    :param threadblock_shape: dimensions of threadblock tiles
    :type threadblock_shape: list or tuple
    :param warp_count: warps to be launched per threadblock dimension
    :type warp_count: list or tuple
    :param stages: number of pipeline stages to use in the kernel
    :type stages: int
    :param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
    :type opclass: cutlass.OpClass
    :param swizzle: threadblock swizzling functor
    :param kernel_schedule: kernel schedule to use
    :type kernel_schedule: cutlass.KernelScheduleType
    :param epilogue_schedule: epilogue schedule to use
    :type epilogue_schedule: cutlass.EpilogueScheduleType
    :param compilation_modes: list of compilers to used in testing the kernel (options: 'nvrtc', 'nvcc')
    :type compilation_modes: list
    """
    for compilation_mode in compilation_modes:
        # Capture the current ``compilation_mode`` eagerly via a keyword
        # default. Without this, ``run`` would close over the loop variable
        # (late binding), and every generated test method -- despite being
        # registered under distinct, mode-specific names -- would execute
        # using only the *last* entry in ``compilation_modes``.
        def run(self, compilation_mode=compilation_mode):
            """
            Dynamically-generated function that constructs a GEMM operation and verifies it against
            multiple test cases.
            """
            # A and B share a single element type in these generated tests
            element_A = element
            element_B = element
            layout_A, layout_B, layout_C = layouts
            alignment_A, alignment_B, alignment_C = alignments

            plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B,
                                   element_C=element_output, element_D=element_output,
                                   layout_A=layout_A, layout_B=layout_B, layout_C=layout_C,
                                   element_accumulator=element_accumulator,
                                   kernel_cc=cc)

            plan.opclass = opclass
            if swizzle is not None:
                plan.swizzling_functor = swizzle

            # Start from the first tile description offered by the plan and
            # override it with the requested shape/stage/cluster settings.
            td = plan.tile_descriptions()[0]
            td.threadblock_shape = threadblock_shape
            td.stages = stages
            if warp_count is not None:
                td.warp_count = warp_count
            td.cluster_shape = cluster_shape
            op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C)
            self.assertTrue(test_all_gemm(op, 'universal', compilation_mode=compilation_mode))

        element_epilogue = element_accumulator
        name = get_name(
            layouts=layouts, alignments=alignments, element_output=element_output, element_accumulator=element_accumulator,
            element_epilogue=element_epilogue, cluster_shape=cluster_shape, threadblock_shape=threadblock_shape,
            stages=stages, element_a=element, element_b=element, arch=cc, opclass=opclass,
            kernel_schedule=kernel_schedule, epilogue_schedule=epilogue_schedule, suffix=f'_{compilation_mode}')

        setattr(cls, name, run)

View File

@@ -135,9 +135,8 @@ void bind_dgrad_swizzle(py::module & m, std::string name) {
:param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC)
:type problem_size: :class:`cutlass.gemm.GemmCoord`)
)pbdoc")
.def("get_grid_shape", [](const T & swizzle, cutlass::gemm::GemmCoord tiled_shape) {
return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k());
}, py::arg("tiled_shape"),
.def("get_grid_shape", &T::get_grid_shape,
py::arg("tiled_shape"),
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
.def("tag", [](const T & swizzle){
return demangle(typeid(T).name());

View File

@@ -155,7 +155,8 @@ void bind_conv_host_references(py::module &m) {
/// Cache
py::class_<test::conv::device::CachedTestKey>(m, "CachedTestKey")
.def(py::init<>())
.def(py::init<std::string, std::string, std::string, uint32_t, uint32_t, uint32_t>());
.def(py::init<std::string, std::string, std::string, uint32_t, uint32_t, uint32_t>())
.def_readwrite("problem", &test::conv::device::CachedTestKey::problem);
py::class_<test::conv::device::CachedTestResult>(m, "CachedTestResult")
.def(py::init<>())

View File

@@ -117,16 +117,16 @@ cutlass::Status ${name}_kernel_run(
typename DeviceKernel::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K, L}, // problem size
A, // ptrA
make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A
B, // ptrB
make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B
{M, N, K, L}, // problem size
A, // ptrA
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A
B, // ptrB
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B
{
C, // ptrC
make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C
D, // ptrD
make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D
C, // ptrC
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C
D, // ptrD
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D
{alpha, beta},
},
hw_info
@@ -180,3 +180,86 @@ cutlass::Status ${name}_kernel_run(int problem_count, cutlass::gemm::GemmCoord*
return status;
}
"""
# C++ source template (CUTLASS 2.x API) for a ${name}_kernel_run launcher:
# builds NHWC TensorRefs from the implicit-GEMM extents of the problem size,
# parses the split-K mode string, then initializes and launches DeviceKernel.
# ``${name}`` and ``${conv_kind_name}`` are substituted at emission time.
_CUTLASS_KERNEL_RUN_CONV2D_2x = """
using UnderlyingKernel = typename DeviceKernel::UnderlyingKernel;
namespace {
using TensorRefA = typename UnderlyingKernel::TensorRefA;
using TensorRefB = typename UnderlyingKernel::TensorRefB;
using TensorRefC = typename UnderlyingKernel::TensorRefC;
using ElementCompute = typename UnderlyingKernel::EpilogueOutputOp::ElementCompute;
}
template<typename TensorRef, typename Element>
TensorRef get_tensor_ref(cutlass::Tensor4DCoord tensor_coord, Element* ptr){
cutlass::layout::TensorNHWC layout = cutlass::layout::TensorNHWC::packed(tensor_coord);
TensorRef tensor_ref(ptr, layout);
return tensor_ref;
}
cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_size,
UnderlyingKernel::ElementA* A, UnderlyingKernel::ElementB* B,
UnderlyingKernel::ElementC* C, UnderlyingKernel::ElementC* D,
ElementCompute alpha, ElementCompute beta, std::string split_k_mode,
cudaStream_t stream, int device_id=0) {
// create the tensor references
cutlass::Tensor4DCoord tensor_coord_A = cutlass::conv::implicit_gemm_tensor_a_extent(
cutlass::conv::Operator::k${conv_kind_name}, *problem_size
);
cutlass::Tensor4DCoord tensor_coord_B = cutlass::conv::implicit_gemm_tensor_b_extent(
cutlass::conv::Operator::k${conv_kind_name}, *problem_size
);
cutlass::Tensor4DCoord tensor_coord_C = cutlass::conv::implicit_gemm_tensor_c_extent(
cutlass::conv::Operator::k${conv_kind_name}, *problem_size
);
TensorRefA tensor_ref_A = get_tensor_ref<TensorRefA, UnderlyingKernel::ElementA>(tensor_coord_A, A);
TensorRefB tensor_ref_B = get_tensor_ref<TensorRefB, UnderlyingKernel::ElementB>(tensor_coord_B, B);
TensorRefC tensor_ref_C = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, C);
TensorRefC tensor_ref_D = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, D);
cutlass::conv::SplitKMode mode;
if (split_k_mode == "serial") {
mode = cutlass::conv::SplitKMode::kSerial;
} else if (split_k_mode == "parallel") {
mode = cutlass::conv::SplitKMode::kParallel;
} else {
throw std::runtime_error("Invalid split_k_mode: " + split_k_mode);
}
typename DeviceKernel::Arguments arguments{
*problem_size,
tensor_ref_A,
tensor_ref_B,
tensor_ref_C,
tensor_ref_D,
{alpha, beta},
mode
};
DeviceKernel implicit_gemm_op;
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
void* workspace_ptr = device_memory_allocation(workspace_size, device_id);
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
return status;
}
status = implicit_gemm_op.initialize(arguments, workspace_ptr, stream);
if (status != cutlass::Status::kSuccess) {
return status;
}
//
// Launch initialized CUTLASS kernel
//
status = implicit_gemm_op(stream);
return status;
}
"""

View File

@@ -85,7 +85,8 @@ import cutlass_bindings
from cutlass import CUTLASS_PATH, logger, swizzle
from cutlass.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal
from cutlass.backend.library import ApiVersion
from cutlass.backend.conv2d_operation import Conv2dOperation
from cutlass.backend.library import ApiVersion, ConvKindNames
from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate
from cutlass.emit import common
@@ -95,12 +96,26 @@ if torch_available:
_PYTORCH_CUDA_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"
// helper function allocating the memory
void* device_memory_allocation(size_t size, int device_id=0) {
if (size > 0) {
torch::Device device(torch::kCUDA, device_id);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device);
at::Tensor device_tensor = torch::empty({(long)size,}, options);
return reinterpret_cast<void*>(device_tensor.data_ptr());
} else {
return nullptr;
}
}
${includes}
${declaration}
${impl}
@@ -143,6 +158,72 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
}
"""
# C++ (pybind11) binding template for a Conv2d fprop extension module.
# NOTE: the py::arg defaults must agree with the C++ declaration defaults
# below; previously py::arg("padding") defaulted to (1, 1) while the C++
# signatures declare padding={0, 0}, and the py::arg value is the one a
# Python caller actually gets.
_PYTORCH_CONV2D_FPROP_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>
// CUDA forward declarations
at::Tensor ${name}_kernel(
const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
float alpha=1.f, float beta=0.f,
std::string split_k_mode="serial", int split_k_slices=1);
// C++ interface
at::Tensor ${name}(
const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
float alpha=1.f, float beta=0.f,
std::string split_k_mode="serial", int split_k_slices=1) {
return ${name}_kernel(A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("run",
py::overload_cast<
const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(0, 0), py::arg("dilation") = std::make_tuple(1, 1),
py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
}
"""
# C++ (pybind11) binding template shared by Conv2d dgrad and wgrad modules;
# these additionally take the result (input/weight) size, since H/W/R/S are
# not uniquely determined by P/Q.
# NOTE: the py::arg defaults must agree with the C++ declaration defaults
# below; previously py::arg("padding") defaulted to (1, 1) while the C++
# signatures declare padding={0, 0}, and the py::arg value is the one a
# Python caller actually gets.
_PYTORCH_CONV2D_GRAD_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>
// CUDA forward declarations
at::Tensor ${name}_kernel(
std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
float alpha=1.f, float beta=0.f,
std::string split_k_mode="serial", int split_k_slices=1);
// C++ interface
at::Tensor ${name}(
std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
float alpha=1.f, float beta=0.f,
std::string split_k_mode="serial", int split_k_slices=1) {
return ${name}_kernel(result_size, A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("run",
py::overload_cast<
std::tuple<int, int, int, int>, const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
&${name}), py::arg("result_size"), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(0, 0), py::arg("dilation") = std::make_tuple(1, 1),
py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
}
"""
_PYTORCH_GEMM_INCLUDES = {
ApiVersion.v2x: """
#include "cutlass/gemm/device/gemm_universal.h"
@@ -162,6 +243,13 @@ _PYTORCH_GROUPED_GEMM_INCLUDES = """
#include "cutlass/gemm/device/gemm_grouped.h"
"""
# Headers required by the generated Conv2d .cu translation unit: the three
# default-kernel builders (fprop/dgrad/wgrad) and the implicit-GEMM device op.
_PYTORCH_CONV2D_INCLUDES = """
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/kernel/default_conv2d_dgrad.h"
#include "cutlass/conv/kernel/default_conv2d_wgrad.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
"""
_CUTLASS_TYPE_TO_TORCH_TYPE = {
cutlass_bindings.float16: "torch::kF16",
cutlass_bindings.float32: "torch::kF32",
@@ -356,6 +444,133 @@ std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const s
"""
)
# Common tail of every Conv2d ${name}_kernel body: invokes the generated
# ${name}_kernel_run launcher on the current CUDA stream and returns D.
# The surrounding fprop/dgrad/wgrad templates supply problem_size, ptrC and D.
_PYTORCH_CONV2D_IMPL_TEMPLATE_2x = """
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
cutlass::Status status = ${name}_kernel_run(
&problem_size,
reinterpret_cast<typename UnderlyingKernel::ElementA*>(A.data_ptr()),
reinterpret_cast<typename UnderlyingKernel::ElementB*>(B.data_ptr()),
ptrC,
reinterpret_cast<typename UnderlyingKernel::ElementC*>(D.data_ptr()),
alpha, beta,
split_k_mode, stream, B.device().index());
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
return D;
}
"""
# fprop implementation: derives N/H/W/C from A and K/R/S from B, builds the
# Conv2dProblemSize, and allocates the channels-last output D of shape
# (N, K, P, Q) before delegating to the shared launcher tail.
_PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x = (
common._CUTLASS_KERNEL_RUN_CONV2D_2x
+ """
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
float alpha=1.f, float beta=0.f, std::string split_k_mode="serial", int split_k_slices=1) {
int N, H, W, C_, K, R, S, P, Q;
N = A.size(0);
C_ = A.size(1);
H = A.size(2);
W = A.size(3);
K = B.size(0);
R = B.size(2);
S = B.size(3);
cutlass::conv::Conv2dProblemSize problem_size(
cutlass::Tensor4DCoord(N, H, W, C_),
cutlass::Tensor4DCoord(K, R, S, C_),
cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
cutlass::conv::Mode::kCrossCorrelation,
split_k_slices
);
P = problem_size.P;
Q = problem_size.Q;
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
nullptr :
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
at::Tensor D = torch::zeros({N, K, P, Q}, options);
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
)
# dgrad implementation: the caller supplies input_size = (N, C, H, W) since
# it is not uniquely determined by the output extents; K/R/S come from the
# weight tensor B, and D is allocated as the (N, C, H, W) input gradient.
_PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x = (
common._CUTLASS_KERNEL_RUN_CONV2D_2x
+ """
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> input_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
std::string split_k_mode="serial", int split_k_slices=1) {
int N, H, W, C_, K, R, S;
N = std::get<0>(input_size);
C_ = std::get<1>(input_size);
H = std::get<2>(input_size);
W = std::get<3>(input_size);
K = B.size(0);
R = B.size(2);
S = B.size(3);
cutlass::conv::Conv2dProblemSize problem_size(
cutlass::Tensor4DCoord(N, H, W, C_),
cutlass::Tensor4DCoord(K, R, S, C_),
cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
cutlass::conv::Mode::kCrossCorrelation,
split_k_slices
);
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
nullptr :
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
at::Tensor D = torch::empty({N, C_, H, W}, options);
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
)
# wgrad implementation: the caller supplies weight_size = (K, C, R, S);
# N/H/W come from the activation tensor B, and D is allocated as the
# (K, C, R, S) weight gradient.
_PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x = (
common._CUTLASS_KERNEL_RUN_CONV2D_2x
+ """
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> weight_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
std::string split_k_mode="serial", int split_k_slices=1) {
int N, H, W, C_, K, R, S;
K = std::get<0>(weight_size);
C_ = std::get<1>(weight_size);
R = std::get<2>(weight_size);
S = std::get<3>(weight_size);
N = B.size(0);
H = B.size(2);
W = B.size(3);
cutlass::conv::Conv2dProblemSize problem_size(
cutlass::Tensor4DCoord(N, H, W, C_),
cutlass::Tensor4DCoord(K, R, S, C_),
cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
cutlass::conv::Mode::kCrossCorrelation,
split_k_slices
);
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
nullptr :
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
at::Tensor D = torch::empty({K, C_, R, S}, options);
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
)
_PYTORCH_SETUP_PY = common._PYSTYLE_AUTOGEN_COMMENT + """
from setuptools import setup
@@ -607,6 +822,73 @@ def _pytorch_grouped_gemm(
return None
def _pytorch_conv2d(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
    """
    Generates source for building a PyTorch CUDA module that leverages the CUTLASS Conv2d
    specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
    compiled, loaded, and returned.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    Note that when the conv kind is `dgrad` or `wgrad`, the size of the input `(N, C, H, W)` or
    weight `(K, C, R, S)` should be provided. This is because there are multiple valid solutions
    for H/W/R/S given the same P/Q.

    :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
    """
    if sourcedir != "" and not os.path.isdir(sourcedir):
        os.makedirs(sourcedir)

    cuda_file = os.path.join(sourcedir, name + "_kernel.cu")

    # Select the .cu implementation and .cpp binding templates per conv kind
    extra_kw = {}
    if op.conv_kind == cutlass_bindings.conv.Operator.fprop:
        impl_template = _PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_FPROP_CPP_TEMPLATE
    elif op.conv_kind == cutlass_bindings.conv.Operator.dgrad:
        impl_template = _PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
    elif op.conv_kind == cutlass_bindings.conv.Operator.wgrad:
        impl_template = _PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
    else:
        # Previously an unrecognized conv kind fell through the if/elif chain
        # and raised a confusing NameError on ``impl_template``; fail loudly
        # with an actionable message instead.
        raise Exception(
            f"Conv kind {op.conv_kind} is not currently supported for PyTorch emission.")

    extra_kw["conv_kind_name"] = ConvKindNames[op.conv_kind].capitalize()
    extra_kw["torch_type_C"] = _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element]
    cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw})

    # Emit the .cu translation unit: kernel declaration plus launcher body
    cuda_source = SubstituteTemplate(
        _PYTORCH_CUDA_TEMPLATE,
        {
            "includes": _PYTORCH_CONV2D_INCLUDES,
            "declaration": op.rt_module.emit(),
            "procedural_name": op.procedural_name(),
            "impl": cuda_impl,
            "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
        },
    )
    with open(cuda_file, "w") as outfile:
        outfile.write(cuda_source)

    # Emit the .cpp pybind11 binding file
    cpp_file = os.path.join(sourcedir, name + ".cpp")
    cpp_source = SubstituteTemplate(
        cpp_template,
        {"name": name, "description": f"CUTLASS {op.procedural_name()} Conv2d"},
    )
    with open(cpp_file, "w") as outfile:
        outfile.write(cpp_source)

    _generate_setup(name, sourcedir)

    if jit:
        return _jit(name, cc, cpp_file, cuda_file)

    return None
def pytorch(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
"""
Generates source for building a PyTorch CUDA module that leverages the CUTLASS kernel
@@ -633,6 +915,8 @@ def pytorch(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
return _pytorch_gemm(device_op, name, cc, jit, sourcedir)
elif isinstance(op, GemmOperationGrouped):
return _pytorch_grouped_gemm(device_op, name, cc, jit, sourcedir)
elif isinstance(op, Conv2dOperation):
return _pytorch_conv2d(device_op, name, cc, jit, sourcedir)
else:
raise Exception(
f"Operation type {type(op)} is not currently supported for PyTorch emission."

View File

@@ -43,6 +43,9 @@ _cuda_version = __version__.split("rc")[0]
# Imports from CUTLASS profiler generator and manifest scripts
import generator as prof_generator
import manifest as prof_manifest
from library import (
ConvKind, IteratorAlgorithm, StrideSupport, GroupMode
)
import cutlass
from cutlass.utils.check import valid_stage_count
@@ -132,6 +135,8 @@ class KernelsForDataType:
ld = shape[0]
elif layout == cutlass.LayoutType.RowMajor:
ld = shape[1]
elif layout == cutlass.LayoutType.TensorNHWC:
ld = shape[-1]
else:
raise Exception(f"Unexpected or unsupported layout {layout}")
@@ -222,8 +227,9 @@ class ArchOptions:
# find available opclasses and data types
for name, op_list in manifest.operations[operation_kind].items():
for op in op_list:
if op.gemm_kind not in gemm_kinds:
continue
if operation_kind == cutlass.OperationKind.Gemm:
if op.gemm_kind not in gemm_kinds:
continue
mi = op.tile_description.math_instruction
if mi.math_operation not in self.allowed_math_operations:
@@ -276,21 +282,36 @@ class ArchOptions:
if cutlass.OpcodeClass.Simt not in self.operations_by_opclass:
self.operations_by_opclass[cutlass.OpcodeClass.Simt] = {}
types = [
(cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s8),
(cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s32),
(cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16),
(cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32),
(cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32),
(cutlass.DataType.f64, cutlass.DataType.f64, cutlass.DataType.f64),
]
if operation_kind == cutlass.OperationKind.Gemm:
types = [
(cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s8),
(cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s32),
(cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16),
(cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32),
(cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32),
(cutlass.DataType.f64, cutlass.DataType.f64, cutlass.DataType.f64),
]
layouts = [
(cutlass.LayoutType.RowMajor, cutlass.LayoutType.RowMajor),
(cutlass.LayoutType.RowMajor, cutlass.LayoutType.ColumnMajor),
(cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.RowMajor),
(cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.ColumnMajor),
]
elif operation_kind == cutlass.OperationKind.Conv2d:
types = [
(cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16),
(cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32),
(cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32),
(cutlass.DataType.f64, cutlass.DataType.f64, cutlass.DataType.f64),
]
layouts = [
(cutlass.LayoutType.TensorNHWC, cutlass.LayoutType.TensorNHWC),
]
else:
raise NotImplementedError(f"Operation kind {operation_kind} is currently unsupported.")
layouts = [
(cutlass.LayoutType.RowMajor, cutlass.LayoutType.RowMajor),
(cutlass.LayoutType.RowMajor, cutlass.LayoutType.ColumnMajor),
(cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.RowMajor),
(cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.ColumnMajor),
]
alignment = 1
epilogue_functor = cutlass.EpilogueFunctor.LinearCombination
swizzling_functor = cutlass.SwizzlingFunctor.Identity8
@@ -319,12 +340,22 @@ class ArchOptions:
if not valid_stage_count(target_cc, td_from_profiler_td(td))[0]:
continue
new_operation = prof_manifest.GemmOperation(
cutlass.GemmKind.Universal, td.minimum_compute_capability,
td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor)
new_kernels = KernelsForDataType(type_comb, layout_comb)
new_kernels.add(new_operation)
if operation_kind == cutlass.OperationKind.Gemm:
new_operation = prof_manifest.GemmOperation(
cutlass.GemmKind.Universal, td.minimum_compute_capability,
td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor)
new_kernels.add(new_operation)
elif operation_kind == cutlass.OperationKind.Conv2d:
for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
new_operation = prof_manifest.Conv2dOperation(
conv_kind, IteratorAlgorithm.Analytic, td.minimum_compute_capability, td,
A, B, C, type_comb[2], StrideSupport.Strided, epilogue_functor, swizzling_functor,
group_mode=GroupMode.SingleGroup
)
new_kernels.add(new_operation)
self.operations_by_opclass[cutlass.OpcodeClass.Simt][comb] = new_kernels
# Sort all operations
@@ -437,9 +468,12 @@ class OptionRegistry:
self.registry = {}
gemm_kinds = [cutlass.GemmKind.Universal, cutlass.GemmKind.Universal3x]
operation_kinds = [cutlass.OperationKind.Gemm, cutlass.OperationKind.Conv2d]
# Construct options for each CC
for kernel_cc in _generator_ccs:
self.registry[kernel_cc] = ArchOptions(target_cc, kernel_cc, cutlass.OperationKind.Gemm, gemm_kinds)
self.registry[kernel_cc] = {}
for opkind in operation_kinds:
self.registry[kernel_cc][opkind] = ArchOptions(target_cc, kernel_cc, opkind, gemm_kinds)
def options_for_cc(self, cc: int) -> ArchOptions:
return self.registry.get(cc, None)
def options_for_cc(self, cc: int, op_kind=cutlass.OperationKind.Gemm) -> ArchOptions:
return self.registry.get(cc, None)[op_kind]

View File

@@ -31,5 +31,6 @@
#################################################################################################
from cutlass.op.gemm import Gemm
from cutlass.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
from cutlass.op.gemm_grouped import GroupedGemm
from cutlass.op.op import OperationBase

960
python/cutlass/op/conv.py Normal file
View File

@@ -0,0 +1,960 @@
#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Ease-of-use interface for constructing, compiling, and running CONVs
The ``Conv2d`` interface is meant to allow one to easily instantiate, compile, and run
CONV2D operations in CUTLASS via Python, without specifying many configuration parameters.
Under the hood, the interface will select sensible default parameters for the many template
parameters for CUTLASS CONVs.
Note: optimal performance is not to be expected from this interface. To achieve optimal
performance, one should specify and tune each configuration parameter.
The simplest example of using this interface is the following:
.. highlight:: python
.. code-block:: python
# A, B, C, and D are torch/numpy/cupy tensor objects
plan = cutlass.op.Conv2d(kind="fprop", A=A, B=B, C=C, D=D)
plan.run(stride=(1, 1), padding=(0, 0), dilation=(1, 1))
One can also use the interface by specifying data types of operands at construction
and using different tensor objects with these data types at runtime:
.. highlight:: python
.. code-block:: python
# The following is shorthand for:
# cutlass.op.Conv2d(kind="fprop",
# element_A=torch.float32, element_B=torch.float32,
# element_C=torch.float32, element_D=torch.float32,
# element_accumulator=torch.float32)
plan = cutlass.op.Conv2d(kind="fprop", element=torch.float32)
A0 = torch.rand((128, 256), dtype=torch.float32, device='cuda')
B0 = torch.rand((256, 64), dtype=torch.float32, device='cuda')
C0 = torch.zeros((128, 64), dtype=torch.float32, device='cuda')
D0 = torch.zeros((128, 64), dtype=torch.float32, device='cuda')
plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
A1 = torch.rand((32, 128), dtype=torch.float32, device='cuda')
B1 = torch.rand((128, 256), dtype=torch.float32, device='cuda')
C1 = torch.zeros((32, 256), dtype=torch.float32, device='cuda')
D1 = torch.zeros((32, 256), dtype=torch.float32, device='cuda')
plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
The interface additionally enables one to decouple the compilation of the underlying CUTLASS
kernel from its execution:
.. highlight:: python
.. code-block:: python
plan = cutlass.op.Conv2d(kind="fprop", element=np.float32)
# Do other work...
plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
# Do other work...
plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
Elementwise activation functions are easily fused to the convolution via the interface
.. highlight:: python
.. code-block:: python
plan = cutlass.op.Conv2d(kind="fprop", element=np.float32)
plan.activation = cutlass.epilogue.relu
Operations can also be run asynchronously:
.. highlight:: python
.. code-block:: python
plan = cutlass.op.Conv2d(kind="fprop", element=np.float32)
args = plan.run()
# Do other work...
args.sync()
"""
import cutlass_bindings
import cutlass
from cutlass import epilogue
from cutlass.backend import compiler
from cutlass.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
from cutlass.backend.reduction_operation import ReductionOperation, ReductionArguments
from cutlass.backend.library import TensorDescription, TileDescription
from cutlass.op.op import OperationBase
from cutlass.utils import check, datatypes
class Conv2d(OperationBase):
"""
Constructs a ``Conv2d`` object.
The convolution kind (fprop, wgrad, dgrad), the data types of operands A, B, and C,
along with the data type of output D and that used for accumulation, are bound to the ``Conv``
object throughout its lifetime -- these are not to be changed after a ``Conv2d`` has been constructed.
The constructor has optional parameters for flexibly setting these parameters. The following
constructors are equivalent:
.. highlight:: python
.. code-block:: python
# Use F32 for A, B, C, D, and accumulation in fprop
# Use the generic ``element`` parameter to concisely set all data types for operands to the same values.
Conv2d(kind="fprop", element=cutlass.DataType.f32)
# Explicitly specify the data types to use for A, B, C, and D.
Conv2d(kind="fprop", element_A=cutlass.DataType.f32, element_B=cutlass.DataType.f32,
element_C=cutlass.DataType.f32, element_D=cutlass.DataType.f32)
# Set the data types and elements from existing tensors. Note that one can use different tensors when
# executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
# have the same data type as those passed in here).
# A, B, C, and D are torch.Tensor objects of type torch.float32 under the channel-last layout
Conv2d(kind="fprop", A=A, B=B, C=C, D=D)
# Explicitly specify the data type for only some of A, B, C, and D. Unspecified data types will inherit
# those passed in via the generic ``element``
Conv2d(kind="fprop", element_A=cutlass.DataType.f32, element_accumulator=cutlass.DataType.f32,
element=cutlass.DataType.f32)
The order of precedence for the setting of the data type for a given operand/output is as follows:
1) If the tensor type is specified (e.g., ``A``), use the data type inferred from this tensor
2) Otherwise, if the data type (e.g., ``element_A``) is specified, use those
3) Otherwise, use the generic values (e.g., ``element``)
:param kind: the convolution kind (i.e. fprop, wgrad, and dgrad)
:type kind: str
:param A: tensor representing data type of operand A
:param B: tensor representing data type of operand B
:param C: tensor representing data type of operand C
:param D: tensor representing data type of operand D
:param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
:param beta: scalar parameter beta from GEMM operation that scales operand C
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
:type element: cutlass.DataType
:param element_A: data type to be used for operand A
:type element_A: cutlass.DataType
:param element_B: data type to be used for operand B
:type element_B: cutlass.DataType
:param element_C: data type to be used for operand C
:type element_C: cutlass.DataType
:param element_D: data type to be used for operand D
:type element_D: cutlass.DataType
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
:type element_accumulator: cutlass.DataType
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
:type cc: int
:param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
:type kernel_cc: int
"""
def __init__(
    self, kind="fprop",
    A=None, B=None, C=None, D=None, alpha=1.0, beta=0.0,
    element=None,
    element_A=None, element_B=None, element_C=None, element_D=None,
    element_accumulator=None,
    cc: int = None, kernel_cc: int = None
):
    """
    Constructs a ``Conv2d`` plan object.

    :param kind: the convolution kind ("fprop", "dgrad", or "wgrad")
    :type kind: str
    :param A: tensor representing the data type of operand A (optional default operand for ``run()``)
    :param B: tensor representing the data type of operand B
    :param C: tensor representing the data type of operand C
    :param D: tensor representing the data type of operand D
    :param alpha: scalar parameter alpha that scales the product of operands A and B
    :param beta: scalar parameter beta that scales operand C
    :param element: generic data type used for any operand whose type is not otherwise specified
    :param element_A: data type to be used for operand A
    :param element_B: data type to be used for operand B
    :param element_C: data type to be used for operand C
    :param element_D: data type to be used for operand D
    :param element_accumulator: data type used for accumulation (defaults to element_C)
    :param cc: compute capability of the device on which kernels run
    :type cc: int
    :param kernel_cc: compute capability of the kernels to generate
    :type kernel_cc: int
    :raises Exception: if both ``element_X`` and tensor ``X`` are given for an operand,
        or if neither is given and no generic ``element`` is provided
    """
    super().__init__(cc=cc, kernel_cc=kernel_cc, operation_kind=cutlass.OperationKind.Conv2d)

    # Verify the kernel cc
    if self.current_cc == 90:
        # The Conv2d kernel on Hopper (SM90) is currently unsupported
        # Revert to use SM80-tagged kernels
        cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
        self.specified_kernel_cc = 80
        self._reset_options(80)

    # The arch is used in testing
    self.arch = self.current_cc
    self.name = "conv2d" + kind

    # The convolution kind. (concept: cutlass_bindings.conv.Operator)
    self.conv_kind = getattr(cutlass_bindings.conv.Operator, kind)

    # The element types (concept: cutlass library types) of A, B, C, and D
    elements = []
    layouts = []

    # Complete the data types based on user-provided arguments, following the
    # precedence: tensor argument > element_X argument > generic element
    for elt, tens, name in zip([element_A, element_B, element_C, element_D],
                               [A, B, C, D],
                               ["A", "B", "C", "D"]):
        if elt is not None and tens is not None:
            raise Exception(f'Must not specify both element_{name} and tensor {name}')
        if elt is None and tens is None and element is None:
            raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')

        elt_to_set = None
        lay_to_set = None

        if tens is not None:
            elt_to_set, _ = datatypes.get_datatype_and_layout(tens)
        else:
            elt_to_set = elt if elt is not None else element
        assert elt_to_set is not None

        # Currently we only support layout TensorNHWC
        lay_to_set = cutlass.LayoutType.TensorNHWC
        elements.append(datatypes.library_type(elt_to_set))
        layouts.append(lay_to_set)

    self._element_a, self._element_b, self._element_c, self._element_d = elements
    self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts

    if element_accumulator is None:
        self._element_accumulator = self._element_c
    else:
        self._element_accumulator = datatypes.library_type(element_accumulator)

    # Default operands and scalars used when run() is called without them.
    # (These were previously assigned twice; a single assignment suffices.)
    self.A, self.B, self.C, self.D = A, B, C, D
    self.alpha, self.beta = alpha, beta

    # We only specify the stride of the swizzling functor here
    # The actual swizzling functor is determined in run based on conv_kind and stride
    self._swizzling_stride = 1

    # Arguments that will be set to default value in _reset_operations
    # The default tile_description and op_class are fetched from manifest of cutlass library
    self._tile_description = None
    self.op_class = None
    # The default identity epilogue will be created
    self.epilogue_functor = None
    self._reset_operations()

    # Arguments that will be determined online based on arguments of "run"
    # based on stride, input/output channels, alignment, and conv_kind
    self._iterator_algorithm = None
    self._stride_support = None
def _reset_operations(self, reset_epilogue: bool = True):
    """
    Refreshes the set of candidate operations (op class, epilogue, preferred
    alignments) for the currently-configured element types and layouts.

    :param reset_epilogue: whether to reset the epilogue functor to the identity activation
    :type reset_epilogue: bool
    :raises Exception: if no kernel exists for the data-type/layout combination
    """
    # Set the default op class
    datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
    layout_comb = (self._layout_a, self._layout_b)
    self.possible_op_classes = self.options.supporting_opclasses(
        self._element_a, self._element_b, self._element_accumulator,
        self._layout_a, self._layout_b
    )
    # Prefer tensor cores when available; otherwise fall back to SIMT
    if cutlass.OpcodeClass.TensorOp in self.possible_op_classes:
        self.opclass = cutlass.OpcodeClass.TensorOp
    elif cutlass.OpcodeClass.Simt in self.possible_op_classes:
        self.opclass = cutlass.OpcodeClass.Simt
    else:
        raise Exception(f'No kernel configuration found for supported data type and layout '
                        f'combination {datatype_comb}x{layout_comb}')
    if reset_epilogue:
        self._reset_epilogue_functor_activation(epilogue.identity)
    # Preferred alignment per operand: at most a 128-bit access, further capped
    # by the alignments the available kernels actually support
    self.alignment_pref_A = min(
        128 // cutlass.DataTypeSize[self._element_a], max(self.possible_operations.alignments))
    self.alignment_pref_B = min(
        128 // cutlass.DataTypeSize[self._element_b], max(self.possible_operations.alignments))
    self.alignment_pref_C = min(
        128 // cutlass.DataTypeSize[self._element_c], max(self.possible_operations.alignments))
#
# Tile description Related
#
@property
def tile_description(self) -> TileDescription:
    """The currently-selected tile description (``None`` until one is set)."""
    return self._tile_description
@tile_description.setter
def tile_description(
    self, td=None):
    """
    Sets the tile description.

    :param td: tile description
    :type td: cutlass.backend.TileDescription, or a dict with keys
        {
            "threadblock_shape": [int, int, int],
            "warp_count": [int, int, int],
            "stages": int,
            "instruction_shape": [int, int, int] (optional),
            "cluster_shape": [int, int, int] (optional)
        }
    :raises Exception: if the resulting tile description is invalid
    """
    # A None argument is a no-op rather than a reset
    if td is None:
        return
    if isinstance(td, dict):
        # When updating from a dict, start from the current tile description,
        # or from a default one derived from the first available kernel
        if self._tile_description is None:
            alignment = list(self.possible_operations.kernels_by_alignment.keys())[0]
            op = self.possible_operations.operations(alignment)[0]
            self._tile_description = datatypes.td_from_profiler_op(op)
        if "cluster_shape" in td.keys():
            # Conv2d kernels do not support multi-CTA clusters; force [1, 1, 1]
            if td["cluster_shape"] != [1, 1, 1]:
                cutlass.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.")
                td["cluster_shape"] = [1, 1, 1]
        td = self._tile_description.clone_and_update(td)
    # Validate before committing the new description
    valid, msg = self._valid_tile_description(td)
    if valid:
        self._tile_description = td
    else:
        raise Exception(msg)
def _valid_tile_description(self, td: TileDescription) -> tuple:
    """
    Checks whether ``td`` is valid for the targeted compute capability:
    supported stage count, tile fits in shared memory, and a legal cluster
    shape for the architecture.

    :param td: tile description to validate
    :type td: cutlass.backend.TileDescription
    :return: (valid, message) where ``valid`` is a bool and ``message`` an
        optional error string
    :rtype: tuple
    """
    # Stage count is checked against the CC we compile for (self.cc); the
    # cluster shape against the CC kernels are drawn from (self.current_cc).
    ok, err = check.valid_stage_count(self.cc, td)
    if ok:
        ok, err = check.valid_cluster_shape(self.current_cc, td.cluster_shape)
    return (ok, err)
def tile_descriptions(self) -> list:
    """
    Returns the distinct tile descriptions available for the operations.

    :return: list of valid tile descriptions for the operations
    :rtype: list
    """
    # Deduplicate by string form while preserving first-seen order
    unique = {}
    for op in self.possible_operations.all_operations:
        td = datatypes.td_from_profiler_op(op)
        unique.setdefault(str(td), td)
    return list(unique.values())
#
# Swizzling functor Related
#
@property
def swizzling_stride(self):
    """
    The stride of the swizzling functor currently used by this Conv2d.

    :return: swizzling stride
    """
    return self._swizzling_stride
@swizzling_stride.setter
def swizzling_stride(self, stride: int):
    """
    Sets the swizzling stride; must be an integer (typically 1, 2, 4, or 8).

    :raises Exception: if ``stride`` is not an int
    """
    if isinstance(stride, int):
        self._swizzling_stride = stride
    else:
        raise Exception(f"Expect integer (1, 2, 4, 8), got {stride}")
def _propose_swizzling_functor(self, stride):
    """
    Selects the swizzling functor implied by the conv kind and stride:
    strided dgrad gets a StridedDgradIdentitySwizzle, everything else an
    IdentitySwizzle, both parameterized by the configured swizzling stride.
    """
    needs_strided_dgrad = (
        self.conv_kind == cutlass_bindings.conv.Operator.dgrad
        and (stride[0] != 1 or stride[1] != 1)
    )
    attr = (f"StridedDgradIdentitySwizzle{self._swizzling_stride}"
            if needs_strided_dgrad
            else f"IdentitySwizzle{self._swizzling_stride}")
    return getattr(cutlass.swizzle, attr)
#
# Iterator Algorithm Related
#
@property
def iterator_algorithm(self) -> cutlass_bindings.conv.IteratorAlgorithm:
    """The iterator algorithm currently selected (``None`` if unset)."""
    return self._iterator_algorithm
@iterator_algorithm.setter
def iterator_algorithm(self, alg: str):
    """
    Sets the iterator algorithm.

    :param alg: one of "analytic", "optimized", "few_channels", "fixed_channels"
    :type alg: str
    :raises Exception: if the algorithm is unsupported for the current conv kind
    """
    # The few/fixed-channel iterators exist only for fprop
    fprop_only = ("few_channels", "fixed_channels")
    if alg in fprop_only and self.conv_kind != cutlass_bindings.conv.Operator.fprop:
        raise Exception(f"{self.conv_kind} does not support iterator algorithm {alg}.")
    self._iterator_algorithm = getattr(cutlass_bindings.conv.IteratorAlgorithm, alg)
def _propose_iterator_algorithm(self, problem_size, alignment_a, alignment_b) -> cutlass_bindings.conv.IteratorAlgorithm:
    """
    Proposes a valid iterator algorithm based on the problem size and operand alignments.

    :param problem_size: the convolution problem size
    :param alignment_a: alignment (in elements) of operand A
    :param alignment_b: alignment (in elements) of operand B
    :return: the proposed iterator algorithm
    """
    if self.conv_kind == cutlass_bindings.conv.Operator.fprop:
        # Check whether the fixed channel is applicable
        if problem_size.C == alignment_a:
            return cutlass_bindings.conv.IteratorAlgorithm.fixed_channels
        elif (problem_size.C % alignment_a == 0 and
              problem_size.R <= 32 and problem_size.S <= 32):
            # Aligned channels and small filter extents permit the optimized iterator
            return cutlass_bindings.conv.IteratorAlgorithm.optimized
        else:
            return cutlass_bindings.conv.IteratorAlgorithm.analytic
    elif self.conv_kind == cutlass_bindings.conv.Operator.dgrad:
        if (problem_size.K % alignment_a == 0 and
            problem_size.R <= 32 and problem_size.S <= 32 and
            problem_size.C % alignment_b == 0):
            return cutlass_bindings.conv.IteratorAlgorithm.optimized
        else:
            return cutlass_bindings.conv.IteratorAlgorithm.analytic
    elif self.conv_kind == cutlass_bindings.conv.Operator.wgrad:
        if (problem_size.K % alignment_a == 0 and
            problem_size.C % alignment_b == 0):
            return cutlass_bindings.conv.IteratorAlgorithm.optimized
        else:
            return cutlass_bindings.conv.IteratorAlgorithm.analytic
    # NOTE(review): falls through returning an implicit None for any other conv
    # kind; __init__ restricts conv_kind to fprop/dgrad/wgrad, so this branch
    # should be unreachable today.
def _validate_iterator_algorithm(self, iterator_algorithm, problem_size, alignment_a, alignment_b) -> bool:
    """
    Checks that a user-supplied iterator algorithm is compatible with the
    given problem size and operand alignments. Combinations with no explicit
    constraint (e.g. the analytic iterator) are always accepted.
    """
    alg = cutlass_bindings.conv.IteratorAlgorithm
    op = cutlass_bindings.conv.Operator
    if self.conv_kind == op.fprop:
        if iterator_algorithm == alg.fixed_channels:
            return problem_size.C == alignment_a
        if iterator_algorithm == alg.optimized:
            return (problem_size.C % alignment_a == 0
                    and problem_size.R <= 32 and problem_size.S <= 32)
        if iterator_algorithm == alg.few_channels:
            return problem_size.C % alignment_a == 0
    elif self.conv_kind == op.dgrad:
        if iterator_algorithm == alg.optimized:
            return (problem_size.K % alignment_a == 0
                    and problem_size.R <= 32 and problem_size.S <= 32
                    and problem_size.C % alignment_b == 0)
    elif self.conv_kind == op.wgrad:
        if iterator_algorithm == alg.optimized:
            return (problem_size.K % alignment_a == 0
                    and problem_size.C % alignment_b == 0)
    # No constraint applies to this (kind, algorithm) pair
    return True
#
# Stride Support Related
#
def _propose_stride_support(self, stride):
    """
    Returns Unity stride support only for unit-stride dgrad; every other
    (kind, stride) combination uses Strided support.
    """
    unit_stride_dgrad = (
        self.conv_kind == cutlass_bindings.conv.Operator.dgrad
        and stride[0] == 1 and stride[1] == 1
    )
    if unit_stride_dgrad:
        return cutlass.backend.library.StrideSupport.Unity
    return cutlass.backend.library.StrideSupport.Strided
#
# Construct and Compilation
#
def construct(
    self, tile_description: TileDescription = None,
    alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
    iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm = None,
    stride_support = None, swizzling_functor: cutlass.swizzle = None,
    epilogue_functor=None) -> cutlass.backend.Conv2dOperation:
    """
    Constructs a ``cutlass.backend.Conv2dOperation`` based on the input parameters and current
    kernel specification of the ``Conv2d`` object.

    :param tile_description: tile description specifying shapes and operand types to use in the kernel
    :type tile_description: cutlass.backend.TileDescription
    :param alignment_A: alignment of operand A
    :type alignment_A: int
    :param alignment_B: alignment of operand B
    :type alignment_B: int
    :param alignment_C: alignment of operand C
    :type alignment_C: int
    :param iterator_algorithm: the iterator algorithm used
    :type iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm
    :param stride_support: the stride support of dgrad
    :type stride_support: cutlass.backend.library.StrideSupport
    :param swizzling_functor: the swizzling functor
    :type swizzling_functor: cutlass.swizzle
    :param epilogue_functor: the epilogue functor

    :return: operation that was constructed
    :rtype: cutlass.backend.Conv2dOperation
    """
    # Get alignment
    alignment_A = check.alignment_or_default(alignment_A, self.alignment_pref_A)
    alignment_B = check.alignment_or_default(alignment_B, self.alignment_pref_B)
    alignment_C = check.alignment_or_default(alignment_C, self.alignment_pref_C)

    tensor_A = TensorDescription(
        datatypes.binding_type(self._element_a),
        # Fixed: previously used self._layout_b here, pairing operand A's
        # element type with operand B's layout
        datatypes.binding_layout(self._layout_a),
        alignment_A
    )
    tensor_B = TensorDescription(
        datatypes.binding_type(self._element_b),
        datatypes.binding_layout(self._layout_b),
        alignment_B
    )
    tensor_C = TensorDescription(
        datatypes.binding_type(self._element_c),
        datatypes.binding_layout(self._layout_c),
        alignment_C
    )

    # Resolve the tile description: explicit argument > previously-set value >
    # a default derived from the first kernel available at alignment_A
    if tile_description is None:
        if self.tile_description is not None:
            tile_description = self.tile_description
        else:
            op = self.possible_operations.operations(alignment_A)[0]
            tile_description = datatypes.td_from_profiler_op(op)
    else:
        valid, err_str = self._valid_tile_description(tile_description)
        if not valid:
            raise Exception(f"Invalid tile description. {err_str}")
        self.tile_description = tile_description

    if iterator_algorithm is None:
        # If the iterator algorithm is already set
        if self.iterator_algorithm is not None:
            iterator_algorithm = self.iterator_algorithm
        else:
            # Otherwise, we conservatively use the analytic iterator for correctness
            iterator_algorithm = cutlass_bindings.conv.IteratorAlgorithm.analytic

    if stride_support is None:
        # If the stride support is already set
        if self._stride_support is not None:
            stride_support = self._stride_support
        else:
            # Otherwise, we assume strided
            stride_support = cutlass.backend.library.StrideSupport.Strided

    if swizzling_functor is None:
        # Default to the strided case, which is always safe
        swizzling_functor = self._propose_swizzling_functor(stride=(2, 2))

    if epilogue_functor is None:
        if self.epilogue_functor is not None:
            epilogue_functor = self.epilogue_functor
        else:
            epilogue_functor = self._create_epilogue_functor_activation(self._activation)

    # Reset the alignment of the epilogue functor
    epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, epilogue_functor)

    operation = Conv2dOperation(
        conv_kind=self.conv_kind,
        iterator_algorithm=iterator_algorithm,
        arch=self.current_cc,
        tile_description=tile_description,
        A=tensor_A, B=tensor_B, C=tensor_C,
        stride_support=stride_support,
        epilogue_functor=epilogue_functor,
        swizzling_functor=swizzling_functor,
    )

    return operation
def compile(self, tile_description: TileDescription = None,
            alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
            iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm = None,
            stride_support = None, swizzling_functor: cutlass.swizzle = None,
            epilogue_functor = None, print_module: bool = False) -> cutlass.backend.Conv2dOperation:
    """
    Emits and compiles the kernel currently specified. If ``tile_description`` and any
    of the ``alignment`` parameters are set, the kernel will be chosen using this
    tile description and alignments. Otherwise, a default tile description and alignment
    will be used.

    :param tile_description: tile description specifying shapes and operand types to use in the kernel
    :type tile_description: cutlass.backend.TileDescription
    :param alignment_A: alignment of operand A
    :type alignment_A: int
    :param alignment_B: alignment of operand B
    :type alignment_B: int
    :param alignment_C: alignment of operand C
    :type alignment_C: int
    :param iterator_algorithm: the iterator algorithm used
    :type iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm
    :param stride_support: the stride support of dgrad
    :type stride_support: cutlass.backend.library.StrideSupport
    :param swizzling_functor: the swizzling functor
    :type swizzling_functor: cutlass.swizzle
    :param epilogue_functor: the epilogue functor
    :param print_module: whether to print the emitted kernel source
    :type print_module: bool

    :return: operation that was compiled
    :rtype: cutlass.backend.Conv2dOperation
    """
    # Build the operation, then hand it to the compiler cache for JIT compilation
    self.operation = self.construct(
        tile_description, alignment_A, alignment_B, alignment_C,
        iterator_algorithm, stride_support, swizzling_functor, epilogue_functor)
    if print_module:
        print(self.operation.rt_module.emit())
    compiler.add_module([self.operation,])
    return self.operation
#
# Run Related
#
def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
    """
    Verifies that ``tensor`` has data type ``ref_type``; raises otherwise.
    Note that only the data type is checked here — ``ref_layout`` is accepted
    for interface symmetry but not validated, since only the NHWC layout is
    currently supported.

    :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
    :type tensor: numpy/cupy/torch array/tensor object
    :param ref_type: data type for the tensor that this object was initialized to
    :param name: identifier of the tensor to verify. Used in raising exceptions
    :type name: str
    """
    dtype, _ = datatypes.get_datatype_and_layout(tensor)
    if dtype == ref_type:
        return
    raise Exception(f'Tensor {name} with type and layout {dtype} '
                    f'does not match the expected type of {ref_type}.')
def _get_and_verify_conv_problem_size(self, A, B, C, stride, padding, dilation):
    """
    Maps (A, B, C) onto (input, weight, output) roles according to the conv
    kind, builds the Conv2dProblemSize, and verifies that the supplied output
    tensor's (P, Q) extent matches the extent implied by the problem.

    :return: the constructed problem size
    :rtype: cutlass_bindings.conv.Conv2dProblemSize
    :raises Exception: if the conv kind is unsupported or the output extent mismatches
    """
    # Which operand plays which role depends on the convolution kind
    if self.conv_kind == cutlass_bindings.conv.Operator.fprop:
        input = A
        weight = B
        output = C
        output_tensor = "C"
    elif self.conv_kind == cutlass_bindings.conv.Operator.dgrad:
        output = A
        weight = B
        input = C
        output_tensor = "A"
    elif self.conv_kind == cutlass_bindings.conv.Operator.wgrad:
        output = A
        input = B
        weight = C
        output_tensor = "A"
    else:
        raise Exception(f"Convolution kind {self.conv_kind} is not supported")

    # NHWC shapes: input (N, H, W, C), weight (K, R, S, C), output (N, P, Q, K)
    N_, H_, W_, C_ = datatypes.get_tensor_shape(input)
    K_, R_, S_, _ = datatypes.get_tensor_shape(weight)
    _, P_, Q_, _ = datatypes.get_tensor_shape(output)

    problem_size = cutlass_bindings.conv.Conv2dProblemSize(
        cutlass_bindings.Tensor4DCoord(N_, H_, W_, C_),
        cutlass_bindings.Tensor4DCoord(K_, R_, S_, C_),
        cutlass_bindings.Tensor4DCoord(padding[0], padding[0], padding[1], padding[1]),
        cutlass_bindings.MatrixCoord(stride[0], stride[1]),
        cutlass_bindings.MatrixCoord(dilation[0], dilation[1]),
        cutlass_bindings.conv.Mode.cross_correlation,
        1, 1  # presumably split-k slices and group count — confirm against Conv2dProblemSize
    )

    # The (P, Q) extents are derived by the problem size; reject mismatched output tensors
    if P_ != problem_size.P or Q_ != problem_size.Q:
        raise Exception(
            f"Tensor {output_tensor} size should be ({N_}, {problem_size.P}, {problem_size.Q}, {K_}), got ({N_}, {P_}, {Q_}, {K_})")

    return problem_size
def run(self, A=None, B=None, C=None, D=None,
        stride=(1, 1), padding=(0, 0), dilation=(1, 1),
        alpha=None, beta=None,
        split_k=("serial", 1), sync: bool = True,
        print_module: bool = False) -> Conv2dArguments:
    """
    Runs the kernel currently specified. If it has not already been, the kernel is emitted and
    compiled. Tensors holding operands and outputs of the kernel are sourced either from the
    ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``
    parameters provided in the call, or from those
    passed in on the construction of this object -- one of the two must be specified.

    By default, this call returns only once the kernel has completed. To launch the kernel
    and immediately return, set ``sync=False``. In this case, it is the responsibility of the
    caller to synchronize the results of the kernel before attempting to access outputs
    by calling ``sync()`` on the arguments returned from this call.

    :param A: tensor representing data type and layout of operand A
    :param B: tensor representing data type and layout of operand B
    :param C: tensor representing data type and layout of operand C
    :param D: tensor representing data type and layout of operand D
    :param stride: (stride_h, stride_w) describing the convolution stride. Default: (1, 1)
    :param padding: (pad_h, pad_w) describing the convolution padding. Default: (0, 0)
    :param dilation: (dilation_h, dilation_w) describing the dilation of convolution. Default: (1, 1)
    :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
    :param beta: scalar parameter beta from GEMM operation that scales operand C
    :param split_k: a tuple (split_k_mode, split_k_slices)
    :param sync: whether the call should wait for the kernel to complete before returning
    :type sync: bool
    :param print_module: whether to print the emitted C++ code
    :type print_module: bool

    :return: arguments passed in to the kernel
    :rtype: cutlass.backend.Conv2dArguments
    """
    # Fall back to the tensors/scalars bound at construction when not provided here
    A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
    B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
    C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
    D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
    alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
    beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")

    # handle the case when there is no C
    if C is None:
        if beta != 0:
            raise Exception(f"With beta {beta} != 0, C has to be provided.")
        else:
            C = D

    # Construct problem size based on input
    # It also verifies whether the A, B, C, D, stride, padding, and dilation are matching
    problem_size = self._get_and_verify_conv_problem_size(A, B, C, stride, padding, dilation)

    # Propose stride support based on input
    stride_support = self._propose_stride_support(stride)

    # Propose swizzling functor
    swizzling_functor = self._propose_swizzling_functor(stride)

    # Get the alignment of each operand from its shape, capped by the preferred alignment
    alignment_a = self.possible_operations.find_alignment(datatypes.get_tensor_shape(A), self._layout_a)
    alignment_b = self.possible_operations.find_alignment(datatypes.get_tensor_shape(B), self._layout_b)
    alignment_c = self.possible_operations.find_alignment(datatypes.get_tensor_shape(C), self._layout_c)
    alignment_a = check.update_alignment(alignment_a, self.alignment_pref_A)
    alignment_b = check.update_alignment(alignment_b, self.alignment_pref_B)
    alignment_c = check.update_alignment(alignment_c, self.alignment_pref_C)

    # Propose iterator algorithm based on input
    if self._iterator_algorithm is None:
        # Propose a default iterator algorithm based on the problem size
        iterator_algorithm = self._propose_iterator_algorithm(problem_size, alignment_a, alignment_b)
    else:
        if (self._validate_iterator_algorithm(self._iterator_algorithm, problem_size, alignment_a, alignment_b)):
            iterator_algorithm = self._iterator_algorithm
        else:
            raise Exception(f"Iterator algorithm {self._iterator_algorithm} is invalid for current problem.")

    # Epilogue arguments: (alpha, beta) plus any activation-specific args
    epilogue_args = [alpha, beta]
    if hasattr(self, "_activation_args"):
        if isinstance(self._activation_args, list):
            epilogue_args += self._activation_args
        else:
            epilogue_args.append(self._activation_args)

    # Parallel split-k accumulates partial results with an identity epilogue;
    # the activation is applied later by the reduction
    if split_k[0] == "parallel" and split_k[1] > 1:
        epilogue_functor = self._create_epilogue_functor_activation(epilogue.identity)
    else:
        epilogue_functor = self.epilogue_functor

    # The alignment is determined by the iterator function (I believe)
    self.compile(tile_description=self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
                 alignment_C=alignment_c, iterator_algorithm=iterator_algorithm, stride_support=stride_support,
                 swizzling_functor=swizzling_functor, epilogue_functor=epilogue_functor, print_module=print_module)

    # Create reduction operation for parallel split-k
    if split_k[0] == "parallel" and split_k[1] > 1:
        epilogue_functor_reduction = self._reset_epilogue_functor_alignment(alignment_c, self.epilogue_functor)
        self.reduction_operation = ReductionOperation(
            shape=cutlass_bindings.MatrixCoord(4, 32 * alignment_c), C=self.operation.C,
            element_accumulator=datatypes.binding_type(self._element_accumulator),
            element_compute=datatypes.binding_type(self._element_accumulator),
            epilogue_functor=epilogue_functor_reduction,
            count=alignment_c
        )
        if print_module:
            print(self.reduction_operation.rt_module.emit())
        compiler.add_module([self.reduction_operation,])

    arguments = Conv2dArguments(
        operation=self.operation, problem_size=problem_size,
        A=A, B=B, C=C, D=D,
        output_op=self.operation.epilogue_type(*epilogue_args),
        split_k_mode=datatypes.getattr_enum(cutlass_bindings.conv.SplitKMode, split_k[0]),
        split_k_slices=split_k[1]
    )

    self.operation.run(arguments)

    # For parallel split-k, reduce the partial results in the workspace into D
    if split_k[0] == "parallel" and split_k[1] > 1:
        implicit_gemm_size = cutlass_bindings.conv.implicit_gemm_problem_size(
            self.conv_kind, arguments.problem_size
        )
        reduction_arguments = ReductionArguments(
            self.reduction_operation,
            problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()],
            partitions=split_k[1],
            workspace=arguments.ptr_D,
            destination=D,
            source=C,
            output_op=self.reduction_operation.epilogue_type(*epilogue_args)
        )
        self.reduction_operation.run(reduction_arguments)

    if sync:
        if split_k[0] == "parallel" and split_k[1] > 1:
            reduction_arguments.sync()
        else:
            arguments.sync()

    return arguments
#
# Helper functions
#
@staticmethod
def output_size(input_size, weight_size, padding, stride, dilation):
    """
    Computes the (N, P, Q, K) output extent of a conv2d with the given
    NHWC input extent, KRSC weight extent, padding, stride, and dilation.
    """
    ps = cutlass_bindings.conv.Conv2dProblemSize(
        cutlass_bindings.Tensor4DCoord(*input_size),
        cutlass_bindings.Tensor4DCoord(*weight_size),
        cutlass_bindings.Tensor4DCoord(padding[0], padding[0], padding[1], padding[1]),
        cutlass_bindings.MatrixCoord(stride[0], stride[1]),
        cutlass_bindings.MatrixCoord(dilation[0], dilation[1]),
        cutlass_bindings.conv.Mode.cross_correlation,
        1, 1
    )
    return (ps.N, ps.P, ps.Q, ps.K)
#
# Easy to use interfaces for fprop, wgrad, and dgrad
#
class Conv2dFprop(Conv2d):
    """
    Convenience interface for forward-propagation convolution. Operands are
    named ``input``, ``weight``, and ``output``, mapping onto the generic
    ``Conv2d`` operands A, B, and D respectively.
    """

    def __init__(
        self,
        input=None, weight=None, C=None, output=None, alpha=1, beta=0,
        element=None,
        element_input=None, element_weight=None, element_C=None, element_output=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None):
        # Map fprop-specific operand names onto the generic A/B/D naming
        A, B, D = input, weight, output
        element_A, element_B, element_D = element_input, element_weight, element_output
        super().__init__(
            "fprop", A, B, C, D, alpha, beta, element,
            element_A, element_B, element_C, element_D,
            element_accumulator, cc, kernel_cc)

    def run(
        self, input=None, weight=None, C=None, output=None, alpha=None, beta=None,
        stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
        sync: bool = True, print_module: bool = False) -> Conv2dArguments:
        """Runs the fprop kernel; see ``Conv2d.run`` for parameter details."""
        A, B, D = input, weight, output
        # Forward by keyword: Conv2d.run declares (A, B, C, D, stride, padding,
        # dilation, alpha, beta, ...); the previous positional call passed
        # alpha/beta into the stride/padding slots.
        return super().run(
            A, B, C, D, stride=stride, padding=padding, dilation=dilation,
            alpha=alpha, beta=beta, split_k=split_k, sync=sync,
            print_module=print_module)
class Conv2dDgrad(Conv2d):
    """
    Convenience interface for data-gradient convolution. Operands are named
    ``grad_output``, ``weight``, and ``grad_input``, mapping onto the generic
    ``Conv2d`` operands A, B, and D respectively.
    """

    def __init__(
        self,
        grad_output=None, weight=None, C=None, grad_input=None, alpha=1, beta=0,
        element=None,
        element_grad_output=None, element_weight=None, element_C=None, element_grad_input=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None):
        # Map dgrad-specific operand names onto the generic A/B/D naming
        A, B, D = grad_output, weight, grad_input
        element_A, element_B, element_D = element_grad_output, element_weight, element_grad_input
        super().__init__(
            "dgrad", A, B, C, D, alpha, beta, element,
            element_A, element_B, element_C, element_D,
            element_accumulator, cc, kernel_cc)

    def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None,
            stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
            sync: bool = True, print_module: bool = False) -> Conv2dArguments:
        """Runs the dgrad kernel; see ``Conv2d.run`` for parameter details."""
        A, B, D = grad_output, weight, grad_input
        # Forward by keyword to avoid misaligning alpha/beta with the parent's
        # stride/padding positional slots.
        return super().run(
            A, B, C, D, stride=stride, padding=padding, dilation=dilation,
            alpha=alpha, beta=beta, split_k=split_k, sync=sync,
            print_module=print_module)
class Conv2dWgrad(Conv2d):
    """
    Convenience interface for weight-gradient convolution. Operands are named
    ``grad_output``, ``input``, and ``grad_weight``, mapping onto the generic
    ``Conv2d`` operands A, B, and D respectively.
    """

    def __init__(
        self,
        grad_output=None, input=None, C=None, grad_weight=None, alpha=1, beta=0,
        element=None,
        element_grad_output=None, element_input=None, element_C=None, element_grad_weight=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None):
        # Map wgrad-specific operand names onto the generic A/B/D naming
        A, B, D = grad_output, input, grad_weight
        element_A, element_B, element_D = element_grad_output, element_input, element_grad_weight
        super().__init__(
            "wgrad", A, B, C, D, alpha, beta, element,
            element_A, element_B, element_C, element_D,
            element_accumulator, cc, kernel_cc)

    def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None,
            stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
            sync: bool = True, print_module: bool = False) -> Conv2dArguments:
        """Runs the wgrad kernel; see ``Conv2d.run`` for parameter details."""
        A, B, D = grad_output, input, grad_weight
        # Forward by keyword to avoid misaligning alpha/beta with the parent's
        # stride/padding positional slots.
        return super().run(
            A, B, C, D, stride=stride, padding=padding, dilation=dilation,
            alpha=alpha, beta=beta, split_k=split_k, sync=sync,
            print_module=print_module)

View File

@@ -287,108 +287,6 @@ class Gemm(OperationBase):
if reset_epilogue:
self._reset_epilogue_functor_activation(epilogue.identity)
def _reset_epilogue_functor_activation(self, activation):
if self.epilogue_functor is None:
if self.op_class == cutlass.OpcodeClass.Simt:
elements_per_access = 1
else:
elements_per_access = 128 // cutlass.DataTypeSize[self._element_c]
else:
elements_per_access = self.epilogue_functor.epilogue_vector_length
if not self.specified_kernel_cc:
if self.current_cc == 90 and activation != epilogue.identity:
# CUTLASS 3.0 kernels currently only support identity activation. If one requests a non-identity activation,
# revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
self._reset_options(80)
self._reset_operations(reset_epilogue=False)
elif (self.cc == 90 and self.current_cc != 90 and activation == epilogue.identity):
# SM80 fallback kernels are currently used. Since an identity activation is requested,
# we can switch back to using SM90 kernels.
self._reset_options(90)
self._reset_operations(reset_epilogue=False)
else:
if self.current_cc == 90 and activation != epilogue.identity:
raise Exception("Epilogues with elementwise fusion are not currently supported "
"in the Python interface for 3.x kernels. To use 2.x kernels "
"with fused elementwise epilogues, do not set the `kernel_cc` "
"parameter when constructing the Gemm object.")
self.epilogue_functor = epilogue.get_activation_epilogue(
activation,
datatypes.binding_type(self._element_c),
elements_per_access,
datatypes.binding_type(self._element_accumulator),
datatypes.binding_type(self._element_accumulator),
)
def _reset_epilogue_functor_alignment(self, alignment):
if self.epilogue_functor is None or not hasattr(self.epilogue_functor, 'activation_functor'):
activation = epilogue.identity
else:
activation = type(self.epilogue_functor.activation_functor)
self.epilogue_functor = epilogue.get_activation_epilogue(
activation,
datatypes.binding_type(self._element_c),
alignment,
datatypes.binding_type(self._element_accumulator),
datatypes.binding_type(self._element_accumulator),
)
@property
def activation(self):
"""
Returns the type of the current activation function used
"""
return type(self.epilogue_functor.activation_functor)
@activation.setter
def activation(self, act):
"""
Sets the type of the activation function to use
"""
self._reset_epilogue_functor_activation(act)
@property
def opclass(self) -> cutlass.OpcodeClass:
"""
Returns the opcode class currently in use by the GEMM
:return: opcode class currently in use
:rtype: cutlass.OpcodeClass
"""
return self.op_class
@opclass.setter
def opclass(self, oc: cutlass.OpcodeClass):
"""
Sets the opcode class to use in the GEMM. If the opcode class is not supported under
the given compute capability and element/layout combinations of the GEMM, an exception is raised.
"""
if oc in self.possible_op_classes:
self.op_class = oc
else:
raise Exception(
f'Unsupported operation class {oc} for CC {self.cc} and data type combination '
f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and '
f'layout combination ({self._layout_a}, {self._layout_b}).')
# Changing the op class changes the elements per access in the epilogue. Reset this.
if self.op_class == cutlass.OpcodeClass.Simt:
elements_per_access = 1
else:
elements_per_access = 128 // cutlass.DataTypeSize[self._element_c]
if self.epilogue_functor is not None:
self._reset_epilogue_functor_alignment(elements_per_access)
# Changing the op class also changes the possible operations available. Reset these.
self.possible_operations = self.options.operations(
self.op_class, self._element_a, self._element_b,
self._element_accumulator, self._layout_a, self._layout_b)
@property
def swizzling_functor(self):
"""
@@ -430,7 +328,7 @@ class Gemm(OperationBase):
"""
# Check stage count based on the CC to which we are compiling (self.cc), rather
# than the CC from which we find kernels (self.current_cc)
valid, msg = check.valid_stage_count(self.cc, td)
valid, msg = check.valid_stage_count(self.cc, td, self._element_c, self._element_d)
if not valid:
return (valid, msg)
@@ -438,7 +336,7 @@ class Gemm(OperationBase):
if not valid:
return (valid, msg)
valid, msg = check.valid_kernel_schedule(self.current_cc, td.kernel_schedule)
valid, msg = check.valid_schedule(self.current_cc, td.kernel_schedule, td.epilogue_schedule, td.tile_scheduler)
return valid, msg
def tile_descriptions(self) -> list:
@@ -476,7 +374,7 @@ class Gemm(OperationBase):
alignment_B = check.alignment_or_default(alignment_B, alignment_pref_B)
alignment_C = check.alignment_or_default(alignment_C, alignment_pref_C)
self._reset_epilogue_functor_alignment(alignment_C)
self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
tensor_A = TensorDescription(
datatypes.binding_type(self._element_a),
@@ -562,68 +460,6 @@ class Gemm(OperationBase):
f'does not match the expected type and '
f'layout of ({ref_type}, {ref_layout}).')
def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name):
"""
Verifies the following properties:
1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``)
2) If ``tensor`` is not ``None``, its datatype and layout must match the current versions
set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``)
If either of these properties does not hold, an exception is raised. If these properties hold and
``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned.
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
:type tensor: numpy/cupy/torch array/tensor object
:param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
:type ref_tensor: numpy/cupy/torch array/tensor object
:param ref_dtype: data type for the tensor that this object was initialized to
:param ref_layout: layout for the tensor that this object was initialized to
:param name: identifier of the tensor to verify. Used in raising exceptions
:type name: str
:return: valid tensor object to use
:rtype: numpy/cupy/torch array/tensor object
"""
if tensor is None:
if ref_tensor is None:
raise Exception(f"Tensor {name} must be set.")
return ref_tensor
self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
return tensor
def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name):
"""
Verifies the following properties:
1) Either ``scalar`` or ``ref_scalar`` must be set (i.e., not ``None``)
2) If ``scalar`` is not ``None``, its datatype must match the current version
set by the plan (i.e., those in ``ref_dtype``)
If either of these properties does not hold, an exception is raised. If these properties hold and
``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned.
:param scalar: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
:type scalar: numpy/cupy/torch scalar
:param ref_scalar: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
:type ref_scalar: numpy/cupy/torch scalar
:param ref_dtype: data type for the scalar that this object was initialized to
:param name: identifier of the scalar to verify. Used in raising exceptions
:type name: str
:return: valid scalar to use
:rtype: numpy/cupy/torch scalar
"""
if scalar is None:
if ref_scalar is None:
raise Exception(f"Scalar {name} must be set.")
return ref_scalar
dtype = datatypes.library_type(scalar.dtype)
if dtype != ref_dtype:
raise Exception(
f"Tensor {name} with type {dtype} does not match expected type {ref_dtype}."
)
return scalar
def run(self, A=None, B=None, C=None, D=None,
alpha=None, beta=None, batch_count: int = 1,
sync: bool = True, print_module: bool = False) -> GemmArguments:

View File

@@ -168,7 +168,7 @@ class GroupedGemm(Gemm):
alignment_B = check.alignment_or_default(alignment_B, alignment_preference)
alignment_C = check.alignment_or_default(alignment_C, alignment_preference)
self._reset_epilogue_functor_alignment(alignment_C)
self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
tensor_A = TensorDescription(
datatypes.binding_type(self._element_a),

View File

@@ -36,11 +36,13 @@ Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv
from bisect import bisect_left
from cutlass import option_registry
import cutlass
from cutlass import option_registry, epilogue
from cutlass.backend.utils.device import device_cc
from cutlass.epilogue import get_activations
from cutlass.library_defaults import _generator_ccs
from cutlass.swizzle import get_swizzling_functors
from cutlass.utils import datatypes
class OperationBase:
@@ -48,22 +50,26 @@ class OperationBase:
Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
"""
def __init__(self, cc: int = None, kernel_cc: int = None):
def __init__(self, cc: int = None, kernel_cc: int = None, operation_kind = cutlass.OperationKind.Gemm):
"""
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
:type cc: int
:param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
:type kernel_cc: int
"""
self.operation_kind = operation_kind
self.cc = cc if cc is not None else device_cc()
self.specified_kernel_cc = kernel_cc is not None
self.current_cc = kernel_cc if kernel_cc is not None else self._find_closest_cc(self.cc)
self.tile_description = None
self.options = option_registry.options_for_cc(self.current_cc)
self.options = option_registry.options_for_cc(self.current_cc, operation_kind)
if self.options is None:
raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}")
# Default activation function: identity
self._activation = epilogue.identity
def _find_closest_cc(self, cc: int) -> int:
"""
@@ -113,4 +119,210 @@ class OperationBase:
if cc not in _generator_ccs:
raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.')
self.current_cc = cc
self.options = option_registry.options_for_cc(self.current_cc)
self.options = option_registry.options_for_cc(self.current_cc, self.operation_kind)
def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name):
"""
Verifies the following properties:
1) Either ``scalar`` or ``ref_scakar`` must be set (i.e., not ``None``)
2) If ``scalar`` is not ``None``, its datatype must match matches the current version
set by the plan (i.e., those in ``ref_dtype``)
If either of these properties does not hold, an exception is raised. If these properties hold and
``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned.
:param scalar: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
:type scalar: numpy/cupy/torch scalar
:param ref_scalar: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
:type ref_scalar: numpy/cupy/torch scalar
:param ref_dtype: data type for the scalar that this object was initialized to
:param name: identifier of the scalar to verify. Used in raising exceptions
:type name: str
:return: valid scalar to use
:rtype: numpy/cupy/torch scalar
"""
if scalar is None:
if ref_scalar is None:
raise Exception(f"Scalar {name} must be set.")
return ref_scalar
if hasattr(scalar, "dtype"):
dtype = datatypes.library_type(scalar.dtype)
if dtype != ref_dtype:
raise Exception(
f"Tensor {name} with type {dtype} does not match expected type {ref_dtype}."
)
return scalar
def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name):
"""
Verifies the following properties:
1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``)
2) If ``tensor`` is not ``None``, its datatype and layout must match matches the current versions
set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``)
If either of these properties does not hold, an exception is raised. If these properties hold and
``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned.
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
:type tensor: numpy/cupy/torch array/tensor object
:param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
:type ref_tensor: numpy/cupy/torch array/tensor object
:param ref_dtype: data type for the tensor that this object was initialized to
:param ref_layout: layout for the tensor that this object was initialized to
:param name: identifier of the tensor to verify. Used in raising exceptions
:type name: str
:return: valid tensor object to use
:rtype: numpy/cupy/torch array/tensor object
"""
if tensor is None:
if ref_tensor is None:
raise Exception(f"Tensor {name} must be set.")
return ref_tensor
self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
return tensor
#
# Opcode Related
#
@property
def opclass(self) -> cutlass.OpcodeClass:
"""
Returns the opcode class currently in use by the GEMM
:return: opcode class currently in use
:rtype: cutlass.OpcodeClass
"""
return self.op_class
@opclass.setter
def opclass(self, oc: cutlass.OpcodeClass):
if isinstance(oc, str):
oc = datatypes.getattr_enum(cutlass.OpcodeClass, oc)
if oc in self.possible_op_classes:
self.op_class = oc
else:
raise Exception(
f'Unsupported operation class {oc} for CC {self.cc} and data type combination '
f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and '
f'layout combination ({self._layout_a}, {self._layout_b}).')
# Changing the op class changes the elements per access in the epilogue. Reset this.
if self.op_class == cutlass.OpcodeClass.Simt:
elements_per_access = 1
else:
elements_per_access = 128 // cutlass.DataTypeSize[self._element_c]
if self.epilogue_functor is not None:
self.epilogue_functor = self._reset_epilogue_functor_alignment(elements_per_access, self.epilogue_functor)
# Changing the op class also changes the possible operations available. Reset these.
self.possible_operations = self.options.operations(
self.op_class, self._element_a, self._element_b,
self._element_accumulator, self._layout_a, self._layout_b)
#
# Epilogue
#
def _create_epilogue_functor_activation(self, activation):
"""
Returns the epilogue functor with given activation function
"""
if self.epilogue_functor is None:
if self.op_class == cutlass.OpcodeClass.Simt:
elements_per_access = 1
else:
elements_per_access = 128 // cutlass.DataTypeSize[self._element_c]
else:
elements_per_access = self.epilogue_functor.epilogue_vector_length
if not self.specified_kernel_cc:
if self.current_cc == 90 and activation != epilogue.identity:
# CUTLASS 3.0 kernels currently only support identity activation. If one requests a non-identity activation,
# revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
self._reset_options(80)
self._reset_operations(reset_epilogue=False)
elif (self.cc == 90 and self.current_cc != 90 and activation == epilogue.identity):
# SM80 fallback kernels are currently used. Since an identity activation is requested,
# we can switch back to using SM90 kernels.
self._reset_options(90)
self._reset_operations(reset_epilogue=False)
else:
if self.current_cc == 90 and activation != epilogue.identity:
raise Exception("Epilogues with elementwise fusion are not currently supported "
"in the Python interface for 3.x kernels. To use 2.x kernels "
"with fused elementwise epilogues, do not set the `kernel_cc` "
"parameter when constructing the Gemm object.")
return epilogue.get_activation_epilogue(
activation,
datatypes.binding_type(self._element_c),
elements_per_access,
datatypes.binding_type(self._element_accumulator),
datatypes.binding_type(self._element_accumulator),
)
def _reset_epilogue_functor_activation(self, activation):
"""
Set the epilogue functor based on the provided activation function
"""
self.epilogue_functor = self._create_epilogue_functor_activation(activation)
def _reset_epilogue_functor_alignment(self, alignment, epilogue_functor):
"""
Reset the alignment of the current epilogue functor based on alignment C
"""
if epilogue_functor is None or not hasattr(epilogue_functor, 'activation_functor'):
# Identity epilogue does not have 'activation_functor'
activation = epilogue.identity
else:
activation = type(epilogue_functor.activation_functor)
epilogue_functor = epilogue.get_activation_epilogue(
activation,
datatypes.binding_type(self._element_c),
alignment,
datatypes.binding_type(self._element_accumulator),
datatypes.binding_type(self._element_accumulator),
)
return epilogue_functor
@property
def activation(self):
"""
Returns the type of the current activation function used
"""
if hasattr(self.epilogue_functor, "activation_functor"):
return type(self.epilogue_functor.activation_functor)
else:
return epilogue.identity
@activation.setter
def activation(self, act):
"""
Sets the type of the activation function to use
Activation can come with a set of arguments
:param act: type of activation function to use
:type act: str or tuple. e.g. "relu", ("leaky_relu", 0.01)
"""
if isinstance(act, tuple):
if isinstance(act[0], str):
act_fn = getattr(cutlass.backend.epilogue, act[0])
else:
act_fn = act[0]
self._reset_epilogue_functor_activation(act_fn)
self._activation_args = act[1]
self._activation = act[0]
else:
if isinstance(act, str):
act = getattr(cutlass.backend.epilogue, act)
self._reset_epilogue_functor_activation(act)
self._activation = act

View File

@@ -32,9 +32,10 @@
from cutlass.utils.check import (
alignment_or_default,
update_alignment,
calculate_smem_usage,
calculate_smem_usage_per_stage,
valid_cluster_shape,
valid_kernel_schedule,
valid_schedule,
valid_stage_count,
)

View File

@@ -39,29 +39,35 @@ import ctypes
import cutlass_bindings
import cutlass
from cutlass.backend.library import DataTypeSize, TileDescription
from cutlass.utils.datatypes import binding_type
def calculate_smem_usage_per_stage(tile_description, operation_kind):
def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: cutlass.OperationKind) -> int:
"""
Returns the amount of shared memory in bytes consumed in a single stage of a kernel.
:param td: tile description to compute shared memory of
:type td: TileDescription
:param operation_kind: identifier for the type of operation being performed
:type operation_kind: cutlass.OperationKind
:return: number of bytes of shared memory consumed by a single stage
:rtype: int
"""
m, n, k = tile_description.threadblock_shape
m, n, k = td.threadblock_shape
if operation_kind == cutlass.OperationKind.Gemm:
stage_barrier_bytes = 32
return (
(DataTypeSize[tile_description.math_instruction.element_a] * m * k // 8)
+ (DataTypeSize[tile_description.math_instruction.element_b] * k * n // 8)
(DataTypeSize[td.math_instruction.element_a] * m * k // 8)
+ (DataTypeSize[td.math_instruction.element_b] * k * n // 8)
+ stage_barrier_bytes
)
else:
raise Exception(f"No available shared memory calculation for operation kind {operation.operation_kind}")
def calculate_smem_usage(operation):
def calculate_smem_usage(operation) -> int:
"""
Returns the amount of shared memory in bytes consumed by a kernel.
@@ -72,7 +78,11 @@ def calculate_smem_usage(operation):
return _per_stage * operation.tile_description.stages
def valid_stage_count(cc: int, td: TileDescription) -> tuple:
def valid_stage_count(
cc: int,
td: TileDescription,
element_C: cutlass.DataType = None,
element_D: cutlass.DataType = None) -> tuple:
"""
Checks whether a device with `cc` supports the number of stages within `tile_description`, both
based on raw limits on the number of stages and based on shared memory capacity
@@ -81,15 +91,26 @@ def valid_stage_count(cc: int, td: TileDescription) -> tuple:
:type cc: int
:param td: tile description to check
:type td: TileDescription
:param element_C: data type of operand C
:type element_C: cutlass.DataType
:param element_D: data type of operand D
:type element_D: cutlass.DataType
:return: tuple with the first element indicating whether the provided tile description is
valid for the provided device and the second element being an error message
:rtype: tuple
"""
if cc == 90 and (td.stages is None or td.stages == 0):
# Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically
# determines the stage count to use. Thus, all settings are valid in these scenarios.
return (True, "")
if cc == 90:
if (td.stages is None or td.stages == 0):
# Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically
# determines the stage count to use. Thus, all settings are valid in these scenarios.
return (True, "")
else:
cutlass.logger.warning(
"Setting an explicit stage count for SM90 kernels currently may "
"result in compilation errors if the combination of tile shape, "
"stage count, and shared memory requirement of the epilogue exceeds "
"the available shared memory per SM.")
if td.stages <= 0:
return (False, f"Stage counts must be positive integers. Tile description has stage count of {td.stages}.")
@@ -98,14 +119,20 @@ def valid_stage_count(cc: int, td: TileDescription) -> tuple:
return (False, f"Tile description has stage count of {td.stages}, "
f"but only 2 stages are supported on SM{cc}.")
# The calculation below does not consider shared memory used by the epilogue and, thus,
# only catches cases in which the mainloop exceeds the device's shared memory capacity.
# This is not a concern for CUTLASS 2.x kernels, for which the shared memory of the
# mainloop and epilogue is shared.
smem_per_stage = calculate_smem_usage_per_stage(td, cutlass.OperationKind.Gemm)
smem_usage_mainloop = (smem_per_stage * td.stages)
smem_arch = cutlass.SharedMemPerCC[cc] << 10
if (smem_per_stage * td.stages) > smem_arch:
if smem_usage_mainloop > smem_arch:
return ( False,
"Configuration uses too much shared memory. Consider reducing stage count or tile shape.\n"
f"Details: configuration uses {smem_per_stage} bytes of shared memory per stage, and "
f"{td.stages} stages for a total of {smem_per_stage * td.stages} bytes.\n"
f"The maxmium amoung of shared memory that can be used per block on CC {cc} is {smem_arch}.")
f"Details:\n"
f"Mainloop uses {smem_per_stage} bytes of shared memory per stage, and "
f"{td.stages} stages for a total of {smem_usage_mainloop} bytes.\n"
f"The maximum amount of shared memory that can be used per block on CC {cc} is {smem_arch}.")
return (True, "")
@@ -153,21 +180,40 @@ def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple:
return (True, "")
def valid_kernel_schedule(cc: int, kernel_schedule: cutlass.KernelScheduleType) -> tuple:
def valid_schedule(
cc: int,
kernel_schedule: cutlass.KernelScheduleType,
epilogue_schedule: cutlass.EpilogueScheduleType,
tile_scheduler: cutlass.TileSchedulerType) -> tuple:
"""
Checks whether a device with ``cc`` supports ``kernel_schedule``.
Checks that the kernel and epilogue schedules passed in are a valid combination for
a device of compute capability ``cc``.
:param cc: compute capability of device in question
:type cc: int
:param kernel_schedule: kernel schedule type
:type KernelScheduleType: cutlass.KernelScheduleType
:type kernel_schedule: cutlass.KernelScheduleType
:param epilogue_schedule: epilogue schedule type
:type epilogue_schedule: cutlass.EpilogueScheduleType
:param tile_scheduler: tile scheduler type
:type tile_scheduler: cutlass.TileSchedulerType
:return: tuple with the first element indicating whether the provided kernel schedule is
:return: tuple with the first element indicating whether the provided schedules are
valid for the provided device and the second element being an error message
:rtype: tuple
"""
if kernel_schedule != cutlass.KernelScheduleType.ScheduleAuto and cc < 90:
return (False, "Non-default kernel schedules are only supported on SM90 and beyond")
kernel_auto = (kernel_schedule == cutlass.KernelScheduleType.ScheduleAuto)
epilogue_auto = (epilogue_schedule == cutlass.EpilogueScheduleType.ScheduleAuto)
tile_scheduler_default = (tile_scheduler == cutlass.TileSchedulerType.Default)
if cc < 90 and not (kernel_auto and epilogue_auto and tile_scheduler_default):
return (False, "Non-default schedules are only supported on SM90 and beyond")
if (kernel_auto and not epilogue_auto) or (not kernel_auto and epilogue_auto):
return (False, "Kernel and epilogue schedules must either both be auto or neither be auto")
if not tile_scheduler_default:
if (tile_scheduler == cutlass.TileSchedulerType.StreamK) and (kernel_schedule != cutlass.KernelScheduleType.TmaWarpSpecializedCooperative):
return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule")
return (True, "")
@@ -190,3 +236,26 @@ def alignment_or_default(alignment_provided: int, default_alignment: int) -> int
return alignment_provided
return default_alignment
def update_alignment(alignment_provided: int, default_alignment: int) -> int:
    """
    Resolve an alignment preference against the maximum supported alignment.

    When no alignment is provided, the default is used. A provided alignment
    at or below the default is used as-is. A provided alignment above the
    default is clamped down to the default when it is a multiple of it;
    otherwise an exception is raised.

    :param alignment_provided: alignment preference specified. Can be None.
    :type alignment_provided: int
    :param default_alignment: alignment to use if `alignment_provided` is None
    :type default_alignment: int
    :return: alignment to use
    :rtype: int
    """
    if alignment_provided is None:
        return default_alignment
    if alignment_provided <= default_alignment:
        return alignment_provided
    if alignment_provided % default_alignment != 0:
        raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.")
    # Oversized but divisible: clamp to the supported maximum.
    return default_alignment

View File

@@ -232,6 +232,8 @@ def library_layout(layout):
return cutlass.LayoutType.RowMajor
elif layout == cutlass_bindings.ColumnMajor:
return cutlass.LayoutType.ColumnMajor
elif layout == cutlass_bindings.TensorNHWC:
return cutlass.LayoutType.TensorNHWC
else:
raise Exception(f"No conversion available for layout {layout} to library layout.")
@@ -251,6 +253,8 @@ def binding_layout(layout):
return cutlass_bindings.RowMajor
elif layout == cutlass.LayoutType.ColumnMajor:
return cutlass_bindings.ColumnMajor
elif layout == cutlass.LayoutType.TensorNHWC:
return cutlass_bindings.TensorNHWC
else:
raise Exception(f"No conversion available for layout {layout} to Python-bound CUTLASS layout.")
@@ -279,6 +283,16 @@ def get_datatype_and_layout(tensor):
else:
raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
def get_tensor_shape(tensor):
    """
    Return the shape of ``tensor`` as an NHWC-ordered tuple.

    numpy/cupy arrays have their shape returned unchanged; torch tensors
    have their size permuted from (dim0, dim1, dim2, dim3) to
    (dim0, dim2, dim3, dim1) — presumably NCHW -> NHWC, matching torch's
    default channels-first layout (confirm against callers).
    """
    if torch_available and isinstance(tensor, torch.Tensor):
        size = tensor.size()
        # Reorder channels-first to channels-last.
        return (size[0], size[2], size[3], size[1])
    is_numpy = numpy_available and isinstance(tensor, np.ndarray)
    is_cupy = cupy_available and isinstance(tensor, cp.ndarray)
    if is_numpy or is_cupy:
        return tensor.shape
    raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
def binding_opclass(opclass: cutlass.OpcodeClass):
if opclass == cutlass.OpcodeClass.TensorOp:
@@ -299,7 +313,9 @@ def backend_math_operation(math_op: cutlass.MathOperation):
def construct_backend_td(td: cutlass.TileDescription,
kernel_schedule: cutlass.KernelScheduleType) -> TileDescription:
kernel_schedule: cutlass.KernelScheduleType,
epilogue_schedule: cutlass.EpilogueScheduleType,
tile_scheduler: cutlass.TileSchedulerType) -> TileDescription:
mi = td.math_instruction
backend_mi = MathInstruction(
mi.instruction_shape,
@@ -309,8 +325,9 @@ def construct_backend_td(td: cutlass.TileDescription,
binding_opclass(mi.opcode_class),
backend_math_operation(mi.math_operation)
)
cluster_shape = td.cluster_shape if hasattr(td, "cluster_shape") else [1, 1, 1]
return TileDescription(td.threadblock_shape, td.stages, td.warp_count,
backend_mi, td.cluster_shape, kernel_schedule)
backend_mi, cluster_shape, kernel_schedule, epilogue_schedule, tile_scheduler)
def td_from_profiler_op(op) -> TileDescription:
@@ -322,8 +339,10 @@ def td_from_profiler_op(op) -> TileDescription:
:returns: backend TileDescription
:rtype: cutlass.backend.TileDescription
"""
schedule = op.kernel_schedule if hasattr(op, 'kernel_schedule') else None
return construct_backend_td(op.tile_description, schedule)
kschedule = op.kernel_schedule if hasattr(op, 'kernel_schedule') else None
eschedule = op.epilogue_schedule if hasattr(op, 'epilogue_schedule') else None
tschedule = op.tile_scheduler if hasattr(op, 'tile_scheduler') else None
return construct_backend_td(op.tile_description, kschedule, eschedule, tschedule)
def td_from_profiler_td(td: cutlass.backend.TileDescription) -> TileDescription:
@@ -336,4 +355,16 @@ def td_from_profiler_td(td: cutlass.backend.TileDescription) -> TileDescription:
:returns: backend TileDescription
:rtype: cutlass.backend.TileDescription
"""
return construct_backend_td(td, kernel_schedule=None)
return construct_backend_td(td, kernel_schedule=None, epilogue_schedule=None, tile_scheduler=None)
def to_camel_case(snake_str):
    """Convert a snake_case string to CamelCase (e.g. "tensor_op" -> "TensorOp")."""
    return "".join(map(str.capitalize, snake_str.lower().split("_")))
def getattr_enum(obj, attr_name):
    """
    Look up the member of ``obj`` named by the snake_case ``attr_name``.

    The name is converted to CamelCase before lookup; an exception is
    raised when ``obj`` has no such member.
    """
    # Members are exposed in CamelCase while options arrive in snake_case.
    camel_attr = to_camel_case(attr_name)
    if not hasattr(obj, camel_attr):
        raise Exception(f"Invalid option: {attr_name}")
    return getattr(obj, camel_attr)

View File

@@ -112,6 +112,7 @@ library_dirs = [
cuda_install_path + '/lib64',
]
ext_modules = [
Pybind11Extension('cutlass_bindings',
['cutlass/cpp/cutlass_bindings.cpp'],