mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
@@ -71,12 +71,15 @@ from library import (
|
||||
DataType,
|
||||
DataTypeSize,
|
||||
EpilogueFunctor,
|
||||
EpilogueScheduleSuffixes,
|
||||
EpilogueScheduleTag,
|
||||
EpilogueScheduleType,
|
||||
GemmKind,
|
||||
LayoutTag,
|
||||
LayoutType,
|
||||
KernelScheduleSuffixes,
|
||||
KernelScheduleType,
|
||||
KernelScheduleTag,
|
||||
KernelScheduleType,
|
||||
MathInstruction,
|
||||
MathOperation,
|
||||
OpcodeClass,
|
||||
@@ -85,6 +88,9 @@ from library import (
|
||||
SwizzlingFunctor,
|
||||
TensorDescription,
|
||||
TileDescription,
|
||||
TileSchedulerSuffixes,
|
||||
TileSchedulerTag,
|
||||
TileSchedulerType
|
||||
)
|
||||
|
||||
this = sys.modules[__name__]
|
||||
@@ -106,11 +112,12 @@ from cutlass.backend.utils.device import device_cc
|
||||
|
||||
this.option_registry = OptionRegistry(device_cc())
|
||||
|
||||
this.__version__ = '3.1.0'
|
||||
this.__version__ = '3.2.0'
|
||||
|
||||
from cutlass.backend import get_memory_pool
|
||||
from cutlass.emit.pytorch import pytorch
|
||||
from cutlass.op.gemm import Gemm
|
||||
from cutlass.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
|
||||
from cutlass.op.gemm_grouped import GroupedGemm
|
||||
from cutlass.op.op import OperationBase
|
||||
|
||||
|
||||
@@ -118,6 +118,7 @@ class GenericMainloopArguments3x_(ctypes.Structure):
|
||||
("stride_A", StrideBatched_),
|
||||
("ptr_B", ctypes.c_void_p),
|
||||
("stride_B", StrideBatched_),
|
||||
("mma_promotion_interval", ctypes.c_int)
|
||||
]
|
||||
|
||||
|
||||
@@ -148,12 +149,14 @@ def get_mainloop_arguments_3x(
|
||||
("stride_A", StrideBatched_),
|
||||
("ptr_B", ctypes.c_void_p),
|
||||
("stride_B", StrideBatched_),
|
||||
("mma_promotion_interval", ctypes.c_int)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def from_generic_mainloop_args(args: GenericMainloopArguments3x_):
|
||||
return _MainloopArgumentsTma(
|
||||
args.ptr_A, args.stride_A, args.ptr_B, args.stride_B,
|
||||
args.mma_promotion_interval
|
||||
)
|
||||
|
||||
class _MainloopArgumentsMultistage(ctypes.Structure):
|
||||
@@ -203,15 +206,23 @@ def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor):
|
||||
("stride_D", StrideBatched_),
|
||||
]
|
||||
|
||||
class _HardwareInfo(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("device_id", ctypes.c_int),
|
||||
("sm_count", ctypes.c_int)
|
||||
]
|
||||
|
||||
class _GemmArguments(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("mode", ctypes.c_int),
|
||||
("problem_size", GemmCoordBatched_),
|
||||
("mainloop", mainloop_arguments),
|
||||
("epilogue", _EpilogueArguments)
|
||||
("epilogue", _EpilogueArguments),
|
||||
("hw_info", _HardwareInfo),
|
||||
("splits", ctypes.c_int)
|
||||
]
|
||||
|
||||
return _GemmArguments, _EpilogueArguments, _EpilogueOutputOpParams
|
||||
return _GemmArguments, _EpilogueArguments, _EpilogueOutputOpParams, _HardwareInfo
|
||||
|
||||
|
||||
def get_gemm_arguments(epilogue_functor):
|
||||
|
||||
@@ -39,16 +39,37 @@ import tempfile
|
||||
from cuda import cuda, nvrtc
|
||||
import cutlass_bindings
|
||||
|
||||
from cutlass import CACHE_FILE, CUDA_INSTALL_PATH, CUTLASS_PATH
|
||||
from cutlass import CACHE_FILE, CUDA_INSTALL_PATH, CUTLASS_PATH, logger
|
||||
from cutlass.backend.gemm_operation import GemmOperationUniversal
|
||||
from cutlass.backend.library import ApiVersion
|
||||
from cutlass.backend.utils.device import device_cc
|
||||
from cutlass.backend.utils.software import SubstituteTemplate
|
||||
import subprocess
|
||||
|
||||
IncludeTemplate = r"""#include "${include}"
|
||||
"""
|
||||
|
||||
|
||||
def compile_with_nvcc(cmd, source, error_file):
|
||||
succeed = True
|
||||
try:
|
||||
subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
error_message = e.output.decode()
|
||||
with open(error_file, "w") as error_out:
|
||||
error_log = "Compilation error for the following kernel: \n"
|
||||
error_log += source
|
||||
error_log += "\nError Message:\n"
|
||||
error_log += error_message
|
||||
error_out.write(error_log)
|
||||
succeed = False
|
||||
if not succeed:
|
||||
# Print the error log to stdout if log level is set to warning or higher
|
||||
# verbosity. Otherwise, simply point to the error log file.
|
||||
logger.warning(error_log)
|
||||
raise Exception(f"Invalid Kernel. See '{error_file}' for details.")
|
||||
|
||||
|
||||
class CompilationOptions:
|
||||
"""
|
||||
Compilation options.
|
||||
@@ -129,20 +150,24 @@ class ArtifactManager:
|
||||
connection.commit()
|
||||
cursor.close()
|
||||
|
||||
self._nvrtc_compile_options = ["-std=c++17", "-default-device"]
|
||||
self._nvcc_compile_options = [
|
||||
"-std=c++17",
|
||||
"--expt-relaxed-constexpr",
|
||||
"-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored",
|
||||
]
|
||||
self.nvcc()
|
||||
self.compiled_cache_device = cutlass_bindings.CompileCache()
|
||||
self.compiled_cache_host = cutlass_bindings.CompileCache()
|
||||
|
||||
def nvrtc(self):
|
||||
self.backend = "nvrtc"
|
||||
self.default_compile_options = ["-std=c++17", "-default-device"]
|
||||
self.default_compile_options = self._nvrtc_compile_options
|
||||
|
||||
def nvcc(self):
|
||||
self.backend = "nvcc"
|
||||
self.default_compile_options = [
|
||||
"-std=c++17",
|
||||
"--expt-relaxed-constexpr",
|
||||
"-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored",
|
||||
]
|
||||
self.default_compile_options = self._nvcc_compile_options
|
||||
|
||||
def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs):
|
||||
connection = sqlite3.connect(CACHE_FILE)
|
||||
cursor = connection.cursor()
|
||||
@@ -200,7 +225,7 @@ class ArtifactManager:
|
||||
self.compiled_cache_host.insert(key, compiled_host_fns)
|
||||
return True
|
||||
|
||||
def emit_compile_(self, operation_list, compilation_options):
|
||||
def emit_compile_(self, operation_list, compilation_options, host_compilation_options):
|
||||
"""
|
||||
Compile a list of kernels and store them into database
|
||||
"""
|
||||
@@ -299,7 +324,7 @@ class ArtifactManager:
|
||||
"tarfile": temp_cubin.name,
|
||||
}
|
||||
cmd = SubstituteTemplate(cmd_template, values)
|
||||
os.system(cmd)
|
||||
compile_with_nvcc(cmd, source_buffer_device, "./cutlass_python_compilation_device_error.txt")
|
||||
|
||||
# load the cubin image
|
||||
with open(temp_cubin.name, "rb") as file:
|
||||
@@ -314,7 +339,7 @@ class ArtifactManager:
|
||||
cmd_template,
|
||||
{
|
||||
"cuda_install_path": CUDA_INSTALL_PATH,
|
||||
"options": compilation_options.get_str(),
|
||||
"options": host_compilation_options.get_str(),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -323,29 +348,31 @@ class ArtifactManager:
|
||||
prefix="host_func", suffix=".so", delete=True)
|
||||
|
||||
cmd += " - -shared -o %s -lcudart -lcuda" % temp.name
|
||||
os.system(cmd)
|
||||
compile_with_nvcc(cmd, source_buffer_host, error_file="./cutlass_python_compilation_host_error.txt")
|
||||
host_lib = ctypes.CDLL(temp.name)
|
||||
|
||||
return cubin_image, host_lib, temp
|
||||
|
||||
def add_module(self, operations, compile_options=None):
|
||||
def add_module(self, operations, compile_options=None, bypass_cache=False):
|
||||
"""
|
||||
Insert a new compiled device module
|
||||
"""
|
||||
if compile_options is None:
|
||||
include_paths = [
|
||||
CUDA_INSTALL_PATH + "/include",
|
||||
CUTLASS_PATH + "/include",
|
||||
CUTLASS_PATH + "/tools/util/include",
|
||||
CUTLASS_PATH + "/python/cutlass/cpp/include",
|
||||
]
|
||||
include_paths = [
|
||||
CUDA_INSTALL_PATH + "/include",
|
||||
CUTLASS_PATH + "/include",
|
||||
CUTLASS_PATH + "/tools/util/include",
|
||||
CUTLASS_PATH + "/python/cutlass/cpp/include",
|
||||
]
|
||||
|
||||
if device_cc() is not None:
|
||||
arch = device_cc()
|
||||
else:
|
||||
# Find the maximum arch tag among the provided operations and compile for that target.
|
||||
# Since we are compiling to .cubin files, only one architecture may be specified.
|
||||
arch = max([op.arch for op in operations])
|
||||
if device_cc() is not None:
|
||||
arch = device_cc()
|
||||
else:
|
||||
# Find the maximum arch tag among the provided operations and compile for that target.
|
||||
# Since we are compiling to .cubin files, only one architecture may be specified.
|
||||
arch = max([op.arch for op in operations])
|
||||
host_compile_options = CompilationOptions(
|
||||
self._nvcc_compile_options, arch, include_paths)
|
||||
if compile_options is None:
|
||||
compile_options = CompilationOptions(
|
||||
self.default_compile_options, arch, include_paths)
|
||||
# save the cubin
|
||||
@@ -357,7 +384,7 @@ class ArtifactManager:
|
||||
# step 1: check if the operation is in cache
|
||||
compiled_kernel = self.compiled_cache_device.at(key)
|
||||
|
||||
if compiled_kernel is None:
|
||||
if compiled_kernel is None and not bypass_cache:
|
||||
hit = self.load_operation(key, getattr( operation.rt_module, "extra_funcs", {}))
|
||||
if hit:
|
||||
compiled_kernel = self.compiled_cache_device.at(key)
|
||||
@@ -375,7 +402,7 @@ class ArtifactManager:
|
||||
|
||||
if len(operation_list) > 0:
|
||||
cubin_image, host_lib, host_file = self.emit_compile_(
|
||||
operation_list, compile_options)
|
||||
operation_list, compile_options, host_compile_options)
|
||||
|
||||
err, module = cuda.cuModuleLoadData(cubin_image)
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
|
||||
@@ -41,6 +41,7 @@ import numpy as np
|
||||
from cutlass.backend.arguments import ArgumentBase
|
||||
from cutlass.backend.c_types import Conv2DProblemSize, TensorRef_, get_conv2d_arguments
|
||||
from cutlass.backend.library import (
|
||||
EmissionType,
|
||||
ConvKindNames,
|
||||
ConvKindTag,
|
||||
DataTypeNames,
|
||||
@@ -123,17 +124,17 @@ class Conv2dArguments(ArgumentBase):
|
||||
super().__init__(A, B, C, D, **kwargs)
|
||||
# preprocessing output ops
|
||||
|
||||
if "output_op" in kwargs.keys() and split_k_mode != cutlass_bindings.conv.SplitKMode.Parallel:
|
||||
self.output_op = kwargs["output_op"]
|
||||
else:
|
||||
self.output_op = self.operation.epilogue_type(1.0, 0.0)
|
||||
|
||||
if "split_k_slices" in kwargs.keys():
|
||||
if "split_k_slices" in kwargs.keys() and kwargs["split_k_slices"] > 1:
|
||||
self.split_k_mode = split_k_mode
|
||||
self.split_k_slices = kwargs["split_k_slices"]
|
||||
else:
|
||||
self.split_k_mode = cutlass_bindings.conv.SplitKMode.Serial
|
||||
self.split_k_slices = 1
|
||||
|
||||
if "output_op" in kwargs.keys() and self.split_k_mode != cutlass_bindings.conv.SplitKMode.Parallel:
|
||||
self.output_op = kwargs["output_op"]
|
||||
else:
|
||||
self.output_op = self.operation.epilogue_type(1.0, 0.0)
|
||||
|
||||
#: problem_size
|
||||
self.problem_size: cutlass_bindings.conv.Conv2dProblemSize = problem_size
|
||||
@@ -419,7 +420,9 @@ class Conv2dOperation:
|
||||
C: TensorDescription,
|
||||
stride_support,
|
||||
epilogue_functor,
|
||||
swizzling_functor=cutlass_bindings.IdentitySwizzle1
|
||||
swizzling_functor=cutlass_bindings.IdentitySwizzle1,
|
||||
emission_type=EmissionType.Kernel,
|
||||
**kwargs
|
||||
):
|
||||
self.operation_kind: OperationKind = OperationKind.Conv2d
|
||||
self.arch: int = arch
|
||||
@@ -432,6 +435,8 @@ class Conv2dOperation:
|
||||
self.iterator_algorithm = iterator_algorithm
|
||||
self.stride_support = stride_support
|
||||
self.swizzling_functor = swizzling_functor()
|
||||
|
||||
self.emission_type = emission_type
|
||||
|
||||
self.rt_module: Conv2dRT = Conv2dRT(self)
|
||||
self.argument_type = self.rt_module.argument_type
|
||||
@@ -562,6 +567,18 @@ class Conv2dOperation:
|
||||
|
||||
return accum
|
||||
|
||||
def device_op(self):
|
||||
"""
|
||||
Returns a new Conv2dOperation object that is constructed with emission type
|
||||
``EmissionType.Device``.
|
||||
|
||||
:return: operation ready for device-level code emission
|
||||
:rtype: Conv2dOperation
|
||||
"""
|
||||
return Conv2dOperation(
|
||||
self.conv_kind, self.iterator_algorithm, self.arch, self.tile_description,
|
||||
self.A, self.B, self.C, self.stride_support, self.epilogue_functor, type(self.swizzling_functor),
|
||||
emission_type=EmissionType.Device)
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
@@ -596,7 +613,7 @@ typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
${epilogue_functor},
|
||||
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
|
||||
${swizzling_functor},
|
||||
${stages},
|
||||
${math_operator},
|
||||
${iterator_algorithm},
|
||||
@@ -608,6 +625,36 @@ typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
|
||||
struct ${operation_name}${operation_suffix}:
|
||||
public ${operation_name}_base { };
|
||||
|
||||
"""
|
||||
|
||||
self.template_device = """
|
||||
// Conv2d operation ${operation_name}
|
||||
|
||||
using Conv2d${conv_kind_name}Kernel = typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
|
||||
${element_a},
|
||||
${layout_a},
|
||||
${element_b},
|
||||
${layout_b},
|
||||
${element_c},
|
||||
${layout_c},
|
||||
${element_accumulator},
|
||||
${opcode_class},
|
||||
${arch},
|
||||
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
||||
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
|
||||
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
||||
${epilogue_functor},
|
||||
${swizzling_functor},
|
||||
${stages},
|
||||
${math_operator},
|
||||
${iterator_algorithm},
|
||||
${stride_support},
|
||||
${align_a},
|
||||
${align_b}
|
||||
>::Kernel;
|
||||
|
||||
using DeviceKernel =
|
||||
typename cutlass::conv::device::ImplicitGemmConvolution<Conv2d${conv_kind_name}Kernel>;
|
||||
"""
|
||||
|
||||
def emit(self, operation):
|
||||
@@ -651,5 +698,10 @@ struct ${operation_name}${operation_suffix}:
|
||||
"align_a": str(operation.A.alignment),
|
||||
"align_b": str(operation.B.alignment),
|
||||
}
|
||||
|
||||
if operation.emission_type == EmissionType.Kernel:
|
||||
conv2d_template = self.template
|
||||
else:
|
||||
conv2d_template = self.template_device
|
||||
|
||||
return SubstituteTemplate(self.template, values)
|
||||
return SubstituteTemplate(conv2d_template, values)
|
||||
|
||||
@@ -39,7 +39,17 @@ import cutlass_bindings
|
||||
import numpy as np
|
||||
import rmm
|
||||
|
||||
from cutlass import KernelScheduleSuffixes, KernelScheduleTag, KernelScheduleType
|
||||
from cutlass import (
|
||||
EpilogueScheduleSuffixes,
|
||||
EpilogueScheduleTag,
|
||||
EpilogueScheduleType,
|
||||
KernelScheduleSuffixes,
|
||||
KernelScheduleTag,
|
||||
KernelScheduleType,
|
||||
TileSchedulerSuffixes,
|
||||
TileSchedulerTag,
|
||||
TileSchedulerType
|
||||
)
|
||||
from cutlass.backend.arguments import ArgumentBase
|
||||
from cutlass.backend.c_types import (
|
||||
GemmCoord_,
|
||||
@@ -55,6 +65,7 @@ from cutlass.backend.c_types import (
|
||||
)
|
||||
from cutlass.backend.library import (
|
||||
ApiVersion,
|
||||
EmissionType,
|
||||
ComplexTransformTag,
|
||||
DataTypeNames,
|
||||
DataTypeSize,
|
||||
@@ -548,6 +559,7 @@ class GemmArguments3x(GemmArguments2x):
|
||||
stride_A,
|
||||
int(self.ptr_B),
|
||||
stride_B,
|
||||
4 # mma_promotion_interval
|
||||
)
|
||||
|
||||
# Set of mainloop arguments needed for this kernel
|
||||
@@ -561,11 +573,15 @@ class GemmArguments3x(GemmArguments2x):
|
||||
stride_D,
|
||||
)
|
||||
|
||||
# Set hardware info
|
||||
hw_info = self.operation.rt_module.hw_info(0, device_sm_count())
|
||||
|
||||
self.arguments = self.operation.argument_type(
|
||||
self.gemm_mode,
|
||||
problem_size_,
|
||||
mainloop,
|
||||
epilogue,
|
||||
hw_info,
|
||||
)
|
||||
return self.arguments
|
||||
|
||||
@@ -1102,6 +1118,11 @@ extern "C" {
|
||||
|
||||
using GemmType = ${operation_name}_base;
|
||||
|
||||
// Get the workspace size
|
||||
uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* argument) {
|
||||
return GemmType::get_workspace_size(*argument);
|
||||
}
|
||||
|
||||
// Get the params as byte array
|
||||
char* ${operation_name}_get_params(GemmType::Arguments* argument, int* workspace){
|
||||
GemmType::Params params = GemmType::to_underlying_arguments(*argument, workspace);
|
||||
@@ -1118,7 +1139,7 @@ extern "C" {
|
||||
uint64_t ${operation_name}_get_persistent_tiled_blk_shape_mnl(GemmType::ProblemShape problem) {
|
||||
auto problem_shape_MNKL = append<4>(problem, Int<1>{});
|
||||
auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] =
|
||||
cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl(
|
||||
cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::get_tiled_cta_shape_mnl(
|
||||
problem_shape_MNKL, GemmType::TileShape{}, GemmType::DispatchPolicy::ClusterShape{});
|
||||
return problem_blocks_m * problem_blocks_n * problem_blocks_l;
|
||||
}
|
||||
@@ -1141,7 +1162,8 @@ extern "C" {
|
||||
self.extra_funcs = {
|
||||
"get_grid_shape": dim3_,
|
||||
"get_block_shape": dim3_,
|
||||
"get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64
|
||||
"get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64,
|
||||
"get_kernel_workspace_size": ctypes.c_uint64
|
||||
}
|
||||
self.emitter = EmitGemmUniversalInstance3x("_type")
|
||||
self.mainloop_args = get_mainloop_arguments_3x(
|
||||
@@ -1151,7 +1173,10 @@ extern "C" {
|
||||
operation.A.alignment,
|
||||
operation.B.alignment
|
||||
)
|
||||
self.argument_type, self.epilogue_args, self.epilogue_type = get_gemm_arguments_3x(self.mainloop_args, operation.epilogue_functor)
|
||||
self.argument_type, self.epilogue_args, self.epilogue_type, self.hw_info = get_gemm_arguments_3x(self.mainloop_args, operation.epilogue_functor)
|
||||
|
||||
def get_device_workspace_size(self, arguments: GemmArguments3x):
|
||||
return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments()))
|
||||
|
||||
|
||||
class EmitGemmUniversalInstance3x:
|
||||
@@ -1183,7 +1208,7 @@ using CollectiveEpilogue =
|
||||
${element_accumulator}, ${element_epilogue},
|
||||
${element_c}, ${layout_c}, ${align_c},
|
||||
${element_d}, ${layout_d}, ${align_d},
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto
|
||||
${epilogue_schedule}
|
||||
>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
@@ -1202,7 +1227,8 @@ using CollectiveMainloop =
|
||||
using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue
|
||||
CollectiveEpilogue,
|
||||
${tile_scheduler}
|
||||
>;
|
||||
|
||||
// Define named type
|
||||
@@ -1233,9 +1259,15 @@ using DeviceKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_nam
|
||||
else:
|
||||
gemm_template = self.gemm_template_device
|
||||
|
||||
schedule = KernelScheduleType.ScheduleAuto
|
||||
kschedule = KernelScheduleType.ScheduleAuto
|
||||
eschedule = EpilogueScheduleType.ScheduleAuto
|
||||
tschedule = TileSchedulerType.Default
|
||||
if operation.tile_description.kernel_schedule is not None:
|
||||
schedule = operation.tile_description.kernel_schedule
|
||||
kschedule = operation.tile_description.kernel_schedule
|
||||
if operation.tile_description.epilogue_schedule is not None:
|
||||
eschedule = operation.tile_description.epilogue_schedule
|
||||
if operation.tile_description.tile_scheduler is not None:
|
||||
tschedule = operation.tile_description.tile_scheduler
|
||||
|
||||
values = {
|
||||
"operation_name": operation.procedural_name(),
|
||||
@@ -1264,7 +1296,9 @@ using DeviceKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_nam
|
||||
"align_c": str(operation.C.alignment),
|
||||
"align_d": str(operation.C.alignment),
|
||||
"stage_count_type": stage_count_type,
|
||||
"kernel_schedule": KernelScheduleTag[schedule],
|
||||
"kernel_schedule": KernelScheduleTag[kschedule],
|
||||
"epilogue_schedule": EpilogueScheduleTag[eschedule],
|
||||
"tile_scheduler": TileSchedulerTag[tschedule]
|
||||
}
|
||||
|
||||
values["epilogue_functor"] = operation.epilogue_functor.emit()
|
||||
@@ -1382,15 +1416,6 @@ ${operation_name}(${operation_name}${operation_suffix}::Params params) {
|
||||
################################################################################
|
||||
|
||||
|
||||
class EmissionType(enum.Enum):
|
||||
"""
|
||||
Tags for whether to emit a kernel- or device-level operation
|
||||
"""
|
||||
|
||||
Kernel = enum_auto()
|
||||
Device = enum_auto()
|
||||
|
||||
|
||||
class GemmOperationBase:
|
||||
"""
|
||||
CUTLASS GEMM operation
|
||||
@@ -1595,11 +1620,18 @@ class GemmOperationBase:
|
||||
else:
|
||||
return KernelScheduleSuffixes[self.tile_description.kernel_schedule]
|
||||
|
||||
# Generates a short string representing underlying epilogue schedule type
|
||||
def epilogue_schedule_name_3x(self):
|
||||
if self.tile_description.epilogue_schedule is None:
|
||||
return EpilogueScheduleSuffixes[EpilogueScheduleType.ScheduleAuto]
|
||||
else:
|
||||
return EpilogueScheduleSuffixes[self.tile_description.epilogue_schedule]
|
||||
|
||||
def procedural_name(self):
|
||||
"""The full procedural name indicates architecture, extended name, tile size, and layout."""
|
||||
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
||||
if self.api == ApiVersion.v3x and self.arch >= 90:
|
||||
kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}"
|
||||
kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}"
|
||||
return kernel_name_template.format(
|
||||
p=self.prefix,
|
||||
ar=self.arch,
|
||||
@@ -1614,7 +1646,8 @@ class GemmOperationBase:
|
||||
l=self.tile_description.stages,
|
||||
s=self.layout_name_3x(),
|
||||
al=str(self.A.alignment),
|
||||
k=self.kernel_schedule_name_3x()
|
||||
k=self.kernel_schedule_name_3x(),
|
||||
e=self.epilogue_schedule_name_3x()
|
||||
)
|
||||
else:
|
||||
threadblock = self.tile_description.procedural_name()
|
||||
|
||||
@@ -38,7 +38,7 @@ but uses the Pybind-bound CUTLASS data types as many keys to the dictionary.
|
||||
import enum
|
||||
|
||||
import cutlass_bindings
|
||||
from cutlass import KernelScheduleType
|
||||
from cutlass import EpilogueScheduleType, KernelScheduleType, TileSchedulerType
|
||||
|
||||
|
||||
# The following block implements enum.auto() for Python 3.5 variants that don't include it such
|
||||
@@ -554,7 +554,9 @@ class TileDescription:
|
||||
warp_count,
|
||||
math_instruction,
|
||||
cluster_shape=[1, 1, 1],
|
||||
kernel_schedule: KernelScheduleType = None
|
||||
kernel_schedule: KernelScheduleType = None,
|
||||
epilogue_schedule: EpilogueScheduleType = None,
|
||||
tile_scheduler: TileSchedulerType = None,
|
||||
):
|
||||
"""
|
||||
:param threadblock_shape: shape of a threadblock tyle
|
||||
@@ -568,18 +570,61 @@ class TileDescription:
|
||||
:type math_instruction: MathInstruction
|
||||
:param cluster_shape: number of threadblocks in the [X, Y, Z] dimensions of a threadblock cluster
|
||||
:param kernel_schedule: type of kernel schedule to use (only available for SM90+)
|
||||
:type kernel_schedule: cutlass.backend.KernelScheduleType
|
||||
:type kernel_schedule: cutlass.KernelScheduleType
|
||||
:param epilogue_schedule: type of epilogue schedule to use (only available for SM90+)
|
||||
:type epilogue_schedule: cutlass.EpilogueScheduleType
|
||||
:param tile_scheduler: type of tile scheduler to use (only available for SM90+)
|
||||
:type tile_scheduler: cutlass.TileSchedulerType
|
||||
"""
|
||||
if ((kernel_schedule is None and epilogue_schedule is not None) or
|
||||
(kernel_schedule is not None and epilogue_schedule is None)):
|
||||
raise Exception("Kernel and epilogue schedule must either both be Auto or neither be Auto.")
|
||||
|
||||
self.threadblock_shape = threadblock_shape
|
||||
self.cluster_shape = cluster_shape
|
||||
self.kernel_schedule = kernel_schedule
|
||||
self.stages: int = stages
|
||||
self.epilogue_schedule = epilogue_schedule
|
||||
self.tile_scheduler = tile_scheduler
|
||||
self.stages = stages
|
||||
|
||||
self.math_instruction = math_instruction
|
||||
self.instruction_shape = math_instruction.instruction_shape
|
||||
|
||||
# Number of warps along x, y, z directions
|
||||
self.warp_count = warp_count
|
||||
|
||||
def clone_and_update(self, td: dict):
|
||||
attrs = {
|
||||
"cluster_shape": None,
|
||||
"threadblock_shape": None,
|
||||
"warp_count": None,
|
||||
"stages": None,
|
||||
"instruction_shape": None,
|
||||
"kernel_schedule": None,
|
||||
"epilogue_schedule": None,
|
||||
"tile_scheduler": None
|
||||
}
|
||||
for key in attrs.keys():
|
||||
if key in td.keys():
|
||||
attrs[key] = td[key]
|
||||
else:
|
||||
attrs[key] = getattr(self, key)
|
||||
|
||||
mi = MathInstruction(
|
||||
attrs["instruction_shape"],
|
||||
self.math_instruction.element_a,
|
||||
self.math_instruction.element_b,
|
||||
self.math_instruction.element_accumulator,
|
||||
self.math_instruction.opcode_class,
|
||||
self.math_instruction.math_operation
|
||||
)
|
||||
|
||||
return TileDescription(
|
||||
attrs["threadblock_shape"], attrs["stages"],
|
||||
attrs["warp_count"], mi, attrs["cluster_shape"],
|
||||
attrs["kernel_schedule"], attrs["epilogue_schedule"]
|
||||
)
|
||||
|
||||
@property
|
||||
def num_threads(self):
|
||||
"""
|
||||
@@ -622,16 +667,30 @@ class TileDescription:
|
||||
:return: contents of tile description
|
||||
:rtype: str
|
||||
"""
|
||||
schedule = KernelScheduleType.ScheduleAuto
|
||||
if self.kernel_schedule is not None:
|
||||
schedule = self.kernel_schedule
|
||||
kschedule = self.kernel_schedule
|
||||
else:
|
||||
kschedule = KernelScheduleType.ScheduleAuto
|
||||
|
||||
if self.epilogue_schedule is not None:
|
||||
eschedule = self.epilogue_schedule
|
||||
else:
|
||||
eschedule = EpilogueScheduleType.ScheduleAuto
|
||||
|
||||
if self.tile_scheduler is not None:
|
||||
tschedule = self.tile_scheduler.name
|
||||
else:
|
||||
tschedule = "None"
|
||||
return f"""
|
||||
{{
|
||||
ClusterShape: {self.cluster_shape}
|
||||
ThreadblockShape: {self.threadblock_shape}
|
||||
WarpCount: {self.warp_count}
|
||||
Stages: {self.stages if self.stages is not None else 'Auto'}
|
||||
Kernel schedule: {schedule.name}
|
||||
InstructionShape: {self.math_instruction.instruction_shape}
|
||||
Kernel schedule: {kschedule.name}
|
||||
Epilogue schedule: {kschedule.name}
|
||||
TileScheduler: {tschedule}
|
||||
}}"""
|
||||
|
||||
|
||||
@@ -712,3 +771,12 @@ def api_version(arch, opclass, datatype):
|
||||
return ApiVersion.v3x
|
||||
else:
|
||||
return ApiVersion.v2x
|
||||
|
||||
|
||||
class EmissionType(enum.Enum):
|
||||
"""
|
||||
Tags for whether to emit a kernel- or device-level operation
|
||||
"""
|
||||
|
||||
Kernel = enum_auto()
|
||||
Device = enum_auto()
|
||||
|
||||
@@ -38,7 +38,7 @@ from bfloat16 import bfloat16
|
||||
import cutlass_bindings
|
||||
import numpy as np
|
||||
|
||||
from cutlass.backend.compiler import ArtifactManager
|
||||
from cutlass.backend import compiler
|
||||
from cutlass.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
|
||||
from cutlass.backend.library import DataTypeSize, ShortDataTypeNames, StrideSupport
|
||||
from cutlass.backend.memory_manager import get_allocated_size
|
||||
@@ -127,7 +127,6 @@ def getTensorView(tensor, tensor_layout, conv_kind, problem_size, operand):
|
||||
raise ValueError("unsupported data type")
|
||||
|
||||
|
||||
# @typechecked
|
||||
class Conv2dLauncher:
|
||||
"""
|
||||
Launcher that runs the operation on given problem size
|
||||
@@ -142,6 +141,7 @@ class Conv2dLauncher:
|
||||
profiling=False,
|
||||
warmup_iterations=500,
|
||||
iterations=500,
|
||||
compilation_mode="nvcc",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.enable_cached_results = True
|
||||
@@ -176,7 +176,14 @@ class Conv2dLauncher:
|
||||
# Compile the operator
|
||||
#
|
||||
|
||||
ArtifactManager().add_module([operation, self.reduction_operation])
|
||||
if compilation_mode == "nvcc":
|
||||
compiler.nvcc()
|
||||
elif compilation_mode == "nvrtc":
|
||||
compiler.nvrtc()
|
||||
else:
|
||||
raise Exception(f"Unexpected compilation mode {compilation_mode}")
|
||||
|
||||
compiler.add_module([operation, self.reduction_operation])
|
||||
|
||||
self.operation = operation
|
||||
|
||||
@@ -195,14 +202,14 @@ class Conv2dLauncher:
|
||||
element_size = DataTypeSize[operation.A.element]
|
||||
|
||||
if element_size <= 8:
|
||||
self.scope = 1
|
||||
self.randomization_max = 1
|
||||
elif element_size == 16:
|
||||
if accumulator_size <= 16:
|
||||
self.scope = 2
|
||||
self.randomization_max = 2
|
||||
else:
|
||||
self.scope = 4
|
||||
self.randomization_max = 4
|
||||
else:
|
||||
self.scope = 7
|
||||
self.randomization_max = 7
|
||||
|
||||
# Seed
|
||||
self.seed = seed
|
||||
@@ -263,12 +270,12 @@ class Conv2dLauncher:
|
||||
if dtype in [np.float32, np.float16, bfloat16, np.float64]:
|
||||
return np.ceil(
|
||||
np.random.uniform(
|
||||
low=-self.scope - 0.5, high=self.scope - 0.5, size=size
|
||||
low=-self.randomization_max - 0.5, high=self.randomization_max - 0.5, size=size
|
||||
).astype(dtype)
|
||||
)
|
||||
else:
|
||||
return np.random.uniform(
|
||||
low=-self.scope - 1, high=self.scope + 1, size=size
|
||||
low=-self.randomization_max - 1, high=self.randomization_max + 1, size=size
|
||||
).astype(dtype)
|
||||
|
||||
def eq_gemm_size(self, problem_size):
|
||||
@@ -624,13 +631,15 @@ class Conv2dLauncher:
|
||||
############################################################################################################
|
||||
|
||||
|
||||
def test_all_conv2d(operation: Conv2dOperation, conv_test_sizes=[], interleaved=False):
|
||||
passed = True
|
||||
#
|
||||
# Testbed object
|
||||
#
|
||||
def test_all_conv2d_from_compilation_mode(
|
||||
operation: Conv2dOperation,
|
||||
conv_test_sizes,
|
||||
interleaved,
|
||||
compilation_mode):
|
||||
|
||||
testbed = Conv2dLauncher(operation, interleaved=interleaved)
|
||||
passed = True
|
||||
|
||||
testbed = Conv2dLauncher(operation, interleaved=interleaved, compilation_mode=compilation_mode)
|
||||
|
||||
#
|
||||
# Get conv problem sizes to run conv operator
|
||||
@@ -781,3 +790,18 @@ def test_all_conv2d(operation: Conv2dOperation, conv_test_sizes=[], interleaved=
|
||||
)
|
||||
|
||||
return passed
|
||||
|
||||
|
||||
def test_all_conv2d(
|
||||
operation: Conv2dOperation,
|
||||
conv_test_sizes=[],
|
||||
interleaved=False,
|
||||
compilation_modes=["nvcc", "nvrtc"]):
|
||||
|
||||
for compilation_mode in compilation_modes:
|
||||
passed = test_all_conv2d_from_compilation_mode(operation, conv_test_sizes, interleaved, compilation_mode)
|
||||
|
||||
if not passed:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@@ -177,6 +177,7 @@ class GemmUniversalLauncher:
|
||||
profiling=False,
|
||||
warmup_iterations=500,
|
||||
iterations=500,
|
||||
compiler_mode: str = "nvcc",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
# create the reduction kernel
|
||||
@@ -209,13 +210,19 @@ class GemmUniversalLauncher:
|
||||
#
|
||||
# Compile the operator
|
||||
#
|
||||
if compiler_mode == "nvcc":
|
||||
compiler.nvcc()
|
||||
elif compiler_mode == "nvrtc":
|
||||
compiler.nvrtc()
|
||||
else:
|
||||
raise Exception(f"Unexpected compiler string {compiler_mode}")
|
||||
|
||||
op_list = [operation]
|
||||
if operation.arch < 90:
|
||||
# Split K via Python is currently only supported for pre-SM90 kernels
|
||||
op_list.append(self.reduction_operation)
|
||||
|
||||
compiler.add_module(op_list)
|
||||
compiler.add_module(op_list, bypass_cache=True)
|
||||
|
||||
self.operation = operation
|
||||
|
||||
@@ -603,7 +610,7 @@ class GemmUniversalLauncher:
|
||||
return passed
|
||||
|
||||
|
||||
def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal"):
|
||||
def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal", compilation_mode="nvcc"):
|
||||
passed = True
|
||||
|
||||
minimum_operand_element_size = min(
|
||||
@@ -711,7 +718,7 @@ def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal"):
|
||||
problem_alpha = [1.0]
|
||||
problem_beta = [2.0]
|
||||
|
||||
testbed = GemmUniversalLauncher(operation, interleaved=(testcase == "interleaved"))
|
||||
testbed = GemmUniversalLauncher(operation, interleaved=(testcase == "interleaved"), compiler_mode=compilation_mode)
|
||||
|
||||
for mode in modes:
|
||||
for m in problem_size_m:
|
||||
|
||||
@@ -30,10 +30,13 @@
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
import cutlass
|
||||
import cutlass_bindings
|
||||
|
||||
from cutlass import KernelScheduleSuffixes
|
||||
from cutlass import EpilogueScheduleSuffixes, KernelScheduleSuffixes
|
||||
from cutlass.utils.datatypes import binding_opclass, binding_type
|
||||
from cutlass.backend import library
|
||||
from cutlass.backend.test.gemm_testbed import test_all_gemm
|
||||
from cutlass.backend.utils.software import SubstituteTemplate
|
||||
|
||||
|
||||
@@ -75,6 +78,7 @@ def get_name(
|
||||
arch,
|
||||
opclass,
|
||||
kernel_schedule=None,
|
||||
epilogue_schedule=None,
|
||||
suffix="",
|
||||
):
|
||||
"""
|
||||
@@ -97,24 +101,26 @@ def get_name(
|
||||
:type opclass: cutlass_bindings.OpClass
|
||||
:param kernel_schedule: kernel_schedule type
|
||||
:type kernel_schedule: cutlass.KernelScheduleType
|
||||
:param epilogue_schedule: epilogue_schedule type
|
||||
:type epilogue_schedule: cutlass.EpilogueScheduleType
|
||||
:param suffix: additional string to add to the suffix of the name
|
||||
:type suffix: str
|
||||
|
||||
:return: str
|
||||
"""
|
||||
name_format = "test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${suffix}"
|
||||
name_format = "test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${e}${suffix}"
|
||||
return SubstituteTemplate(
|
||||
name_format,
|
||||
{
|
||||
"arch": str(arch),
|
||||
"eA": library.DataTypeNames[element_a],
|
||||
"eB": library.DataTypeNames[element_b],
|
||||
"eC": library.DataTypeNames[element_output],
|
||||
"eA": library.DataTypeNames[binding_type(element_a)],
|
||||
"eB": library.DataTypeNames[binding_type(element_b)],
|
||||
"eC": library.DataTypeNames[binding_type(element_output)],
|
||||
"lA": library.ShortLayoutTypeNames[layouts[0]],
|
||||
"lB": library.ShortLayoutTypeNames[layouts[1]],
|
||||
"lC": library.ShortLayoutTypeNames[layouts[2]],
|
||||
"opclass": library.OpcodeClassNames[opclass],
|
||||
"acc": library.DataTypeNames[element_accumulator],
|
||||
"opclass": library.OpcodeClassNames[binding_opclass(opclass)],
|
||||
"acc": library.DataTypeNames[binding_type(element_accumulator)],
|
||||
"cM": str(cluster_shape[0]),
|
||||
"cN": str(cluster_shape[1]),
|
||||
"cK": str(cluster_shape[2]),
|
||||
@@ -126,6 +132,174 @@ def get_name(
|
||||
"aB": str(alignments[1]),
|
||||
"aC": str(alignments[2]),
|
||||
"k": "" if kernel_schedule is None else KernelScheduleSuffixes[kernel_schedule],
|
||||
"e": "" if epilogue_schedule is None else EpilogueScheduleSuffixes[epilogue_schedule],
|
||||
"suffix": "" if suffix is None else suffix,
|
||||
},
|
||||
)
|
||||
|
||||
def get_name_conv2d(
|
||||
arch,
|
||||
conv_kind,
|
||||
element,
|
||||
element_accumulator,
|
||||
element_output,
|
||||
opclass,
|
||||
threadblock_shape,
|
||||
warp_count,
|
||||
instruction_shape,
|
||||
stages,
|
||||
iterator_algorithm,
|
||||
swizzle,
|
||||
split_k_mode,
|
||||
split_k_slices,
|
||||
activation
|
||||
):
|
||||
"""
|
||||
Generates a procedural name for a test case for conv2d
|
||||
|
||||
:param arch: compute capability of kernel being generated
|
||||
:type arch: int
|
||||
:param conv_kind: the convolution type (i.e. fprop, dgrad, wgrad)
|
||||
:type conv_kind: str
|
||||
:param iterator_algorithm: the iterator algorithm applied
|
||||
:type iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm
|
||||
:param element_a: data type of operand A
|
||||
:param element_b: data type of operand B
|
||||
:param element_c: data type of operand C
|
||||
:param element_accumulator: data type used in accumulation
|
||||
:param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
|
||||
:type opclass: cutlass_bindings.OpClass
|
||||
:param threadblock_shape: indexable container of dimensions of threadblock tiles
|
||||
:param stages: number of pipeline stages to use in the kernel
|
||||
:type stages: int
|
||||
:param stride_support: stride support of dgrad
|
||||
:param alignment: int
|
||||
:type alignment: int
|
||||
|
||||
:return: str
|
||||
"""
|
||||
if iterator_algorithm is None:
|
||||
iterator_algorithm = "AUTO"
|
||||
if swizzle is None:
|
||||
swizzle = 1
|
||||
name_format = "test_SM${arch}_Device_Conv2d_${conv_kind}_${iter_alg}_ImplicitGemm_${eA}nhwc_${eB}nhwc_${eC}nhwc_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${wM}x${wN}x${wK}_${IM}${IN}${IK}_stage${stages}_swizzle${swizzle}_${split_k_mode}${split_k_slices}_${activation}"
|
||||
|
||||
return SubstituteTemplate(
|
||||
name_format,
|
||||
{
|
||||
"arch": str(arch),
|
||||
"conv_kind": conv_kind,
|
||||
"iter_alg": iterator_algorithm,
|
||||
"eA": library.DataTypeNames[binding_type(element)],
|
||||
"eB": library.DataTypeNames[binding_type(element)],
|
||||
"eC": library.DataTypeNames[binding_type(element_output)],
|
||||
"opclass": opclass,
|
||||
"acc": library.DataTypeNames[binding_type(element_accumulator)],
|
||||
"tbM": str(threadblock_shape[0]),
|
||||
"tbN": str(threadblock_shape[1]),
|
||||
"tbK": str(threadblock_shape[2]),
|
||||
"wM": str(threadblock_shape[0] // warp_count[0]),
|
||||
"wN": str(threadblock_shape[1] // warp_count[1]),
|
||||
"wK": str(threadblock_shape[2] // warp_count[2]),
|
||||
"IM": str(instruction_shape[0]),
|
||||
"IN": str(instruction_shape[1]),
|
||||
"IK": str(instruction_shape[2]),
|
||||
"stages": str(stages),
|
||||
"swizzle": str(swizzle),
|
||||
"split_k_mode": split_k_mode,
|
||||
"split_k_slices": str(split_k_slices),
|
||||
"activation": activation
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def add_test_gemm(
|
||||
cls=None,
|
||||
cc=None,
|
||||
element=None,
|
||||
layouts=None,
|
||||
alignments=None,
|
||||
element_output=None,
|
||||
element_accumulator=None,
|
||||
cluster_shape=None,
|
||||
threadblock_shape=None,
|
||||
warp_count=None,
|
||||
stages=None,
|
||||
opclass=None,
|
||||
swizzle=None,
|
||||
kernel_schedule=None,
|
||||
epilogue_schedule=None,
|
||||
compilation_modes=['nvcc', 'nvrtc']):
|
||||
"""
|
||||
Create test-running functions with the given specification and set it as a method of ``cls``.
|
||||
|
||||
:param cls: class to which the generated method will be added
|
||||
:type cls: type
|
||||
:param cc: compute capability to compile for
|
||||
:type cc: int
|
||||
:param element: data type of A and B operands
|
||||
:type element: cutlass.DataType.f16
|
||||
:param layouts: layouts of A, B, and C operands
|
||||
:type layouts: list or tuple
|
||||
:param alignments: alingments of A, B, and C operands
|
||||
:type alignments: list or tuple
|
||||
:param element_output: data type of the output element
|
||||
:type element_output: cutlass.DataType
|
||||
:param element_accumulator: data type used in accumulation
|
||||
:type element_accumulator: cutlass.DataType
|
||||
:param cluster_shape: dimensions of clusters
|
||||
:type cluster_shape: list or tuple
|
||||
:param threadblock_shape: dimensions of threadblock tiles
|
||||
:type threadblock_shape: list or tuple
|
||||
:param warp_count: warps to be launched per threadblock dimension
|
||||
:type warp_count: list or tuple
|
||||
:param stages: number of pipeline stages to use in the kernel
|
||||
:type stages: int
|
||||
:param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
|
||||
:type opclass: cutlass.OpClass
|
||||
:param swizzle: threadblock swizzling functor
|
||||
:param kernel_schedule: kernel schedule to use
|
||||
:type kernel_schedule: cutlass.KernelScheduleType
|
||||
:param epilogue_schedule: epilogue schedule to use
|
||||
:type epilogue_schedule: cutlass.EpilogueScheduleType
|
||||
:param compilation_modes: list of compilers to used in testing the kernel (options: 'nvrtc', 'nvcc')
|
||||
:type compilation_modes: list
|
||||
"""
|
||||
|
||||
for compilation_mode in compilation_modes:
|
||||
def run(self):
|
||||
"""
|
||||
Dynamically-generated function that constructs a GEMM operation and verifies it against
|
||||
multiple test cases.
|
||||
"""
|
||||
element_A = element
|
||||
element_B = element
|
||||
layout_A, layout_B, layout_C = layouts
|
||||
alignment_A, alignment_B, alignment_C = alignments
|
||||
|
||||
plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B,
|
||||
element_C=element_output, element_D=element_output,
|
||||
layout_A=layout_A, layout_B=layout_B, layout_C=layout_C,
|
||||
element_accumulator=element_accumulator,
|
||||
kernel_cc=cc)
|
||||
|
||||
plan.opclass = opclass
|
||||
if swizzle is not None:
|
||||
plan.swizzling_functor = swizzle
|
||||
td = plan.tile_descriptions()[0]
|
||||
td.threadblock_shape = threadblock_shape
|
||||
td.stages = stages
|
||||
if warp_count is not None:
|
||||
td.warp_count = warp_count
|
||||
td.cluster_shape = cluster_shape
|
||||
op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C)
|
||||
self.assertTrue(test_all_gemm(op, 'universal', compilation_mode=compilation_mode))
|
||||
|
||||
element_epilogue = element_accumulator
|
||||
name = get_name(
|
||||
layouts=layouts, alignments=alignments, element_output=element_output, element_accumulator=element_accumulator,
|
||||
element_epilogue=element_epilogue, cluster_shape=cluster_shape, threadblock_shape=threadblock_shape,
|
||||
stages=stages, element_a=element, element_b=element, arch=cc, opclass=opclass,
|
||||
kernel_schedule=kernel_schedule, epilogue_schedule=epilogue_schedule, suffix=f'_{compilation_mode}')
|
||||
|
||||
setattr(cls, name, run)
|
||||
|
||||
@@ -135,9 +135,8 @@ void bind_dgrad_swizzle(py::module & m, std::string name) {
|
||||
:param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC)
|
||||
:type problem_size: :class:`cutlass.gemm.GemmCoord`)
|
||||
)pbdoc")
|
||||
.def("get_grid_shape", [](const T & swizzle, cutlass::gemm::GemmCoord tiled_shape) {
|
||||
return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k());
|
||||
}, py::arg("tiled_shape"),
|
||||
.def("get_grid_shape", &T::get_grid_shape,
|
||||
py::arg("tiled_shape"),
|
||||
R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc")
|
||||
.def("tag", [](const T & swizzle){
|
||||
return demangle(typeid(T).name());
|
||||
|
||||
@@ -155,7 +155,8 @@ void bind_conv_host_references(py::module &m) {
|
||||
/// Cache
|
||||
py::class_<test::conv::device::CachedTestKey>(m, "CachedTestKey")
|
||||
.def(py::init<>())
|
||||
.def(py::init<std::string, std::string, std::string, uint32_t, uint32_t, uint32_t>());
|
||||
.def(py::init<std::string, std::string, std::string, uint32_t, uint32_t, uint32_t>())
|
||||
.def_readwrite("problem", &test::conv::device::CachedTestKey::problem);
|
||||
|
||||
py::class_<test::conv::device::CachedTestResult>(m, "CachedTestResult")
|
||||
.def(py::init<>())
|
||||
|
||||
@@ -117,16 +117,16 @@ cutlass::Status ${name}_kernel_run(
|
||||
|
||||
typename DeviceKernel::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{M, N, K, L}, // problem size
|
||||
A, // ptrA
|
||||
make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A
|
||||
B, // ptrB
|
||||
make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B
|
||||
{M, N, K, L}, // problem size
|
||||
A, // ptrA
|
||||
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A
|
||||
B, // ptrB
|
||||
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B
|
||||
{
|
||||
C, // ptrC
|
||||
make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C
|
||||
D, // ptrD
|
||||
make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D
|
||||
C, // ptrC
|
||||
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C
|
||||
D, // ptrD
|
||||
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D
|
||||
{alpha, beta},
|
||||
},
|
||||
hw_info
|
||||
@@ -180,3 +180,86 @@ cutlass::Status ${name}_kernel_run(int problem_count, cutlass::gemm::GemmCoord*
|
||||
return status;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
_CUTLASS_KERNEL_RUN_CONV2D_2x = """
|
||||
|
||||
using UnderlyingKernel = typename DeviceKernel::UnderlyingKernel;
|
||||
namespace {
|
||||
using TensorRefA = typename UnderlyingKernel::TensorRefA;
|
||||
using TensorRefB = typename UnderlyingKernel::TensorRefB;
|
||||
using TensorRefC = typename UnderlyingKernel::TensorRefC;
|
||||
using ElementCompute = typename UnderlyingKernel::EpilogueOutputOp::ElementCompute;
|
||||
}
|
||||
|
||||
template<typename TensorRef, typename Element>
|
||||
TensorRef get_tensor_ref(cutlass::Tensor4DCoord tensor_coord, Element* ptr){
|
||||
cutlass::layout::TensorNHWC layout = cutlass::layout::TensorNHWC::packed(tensor_coord);
|
||||
TensorRef tensor_ref(ptr, layout);
|
||||
return tensor_ref;
|
||||
}
|
||||
|
||||
cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_size,
|
||||
UnderlyingKernel::ElementA* A, UnderlyingKernel::ElementB* B,
|
||||
UnderlyingKernel::ElementC* C, UnderlyingKernel::ElementC* D,
|
||||
ElementCompute alpha, ElementCompute beta, std::string split_k_mode,
|
||||
cudaStream_t stream, int device_id=0) {
|
||||
// create the tensor references
|
||||
cutlass::Tensor4DCoord tensor_coord_A = cutlass::conv::implicit_gemm_tensor_a_extent(
|
||||
cutlass::conv::Operator::k${conv_kind_name}, *problem_size
|
||||
);
|
||||
cutlass::Tensor4DCoord tensor_coord_B = cutlass::conv::implicit_gemm_tensor_b_extent(
|
||||
cutlass::conv::Operator::k${conv_kind_name}, *problem_size
|
||||
);
|
||||
cutlass::Tensor4DCoord tensor_coord_C = cutlass::conv::implicit_gemm_tensor_c_extent(
|
||||
cutlass::conv::Operator::k${conv_kind_name}, *problem_size
|
||||
);
|
||||
|
||||
TensorRefA tensor_ref_A = get_tensor_ref<TensorRefA, UnderlyingKernel::ElementA>(tensor_coord_A, A);
|
||||
TensorRefB tensor_ref_B = get_tensor_ref<TensorRefB, UnderlyingKernel::ElementB>(tensor_coord_B, B);
|
||||
TensorRefC tensor_ref_C = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, C);
|
||||
TensorRefC tensor_ref_D = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, D);
|
||||
|
||||
cutlass::conv::SplitKMode mode;
|
||||
if (split_k_mode == "serial") {
|
||||
mode = cutlass::conv::SplitKMode::kSerial;
|
||||
} else if (split_k_mode == "parallel") {
|
||||
mode = cutlass::conv::SplitKMode::kParallel;
|
||||
} else {
|
||||
throw std::runtime_error("Invalid split_k_mode: " + split_k_mode);
|
||||
}
|
||||
|
||||
typename DeviceKernel::Arguments arguments{
|
||||
*problem_size,
|
||||
tensor_ref_A,
|
||||
tensor_ref_B,
|
||||
tensor_ref_C,
|
||||
tensor_ref_D,
|
||||
{alpha, beta},
|
||||
mode
|
||||
};
|
||||
|
||||
DeviceKernel implicit_gemm_op;
|
||||
|
||||
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
|
||||
|
||||
void* workspace_ptr = device_memory_allocation(workspace_size, device_id);
|
||||
|
||||
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = implicit_gemm_op.initialize(arguments, workspace_ptr, stream);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
//
|
||||
// Launch initialized CUTLASS kernel
|
||||
//
|
||||
status = implicit_gemm_op(stream);
|
||||
|
||||
return status;
|
||||
}
|
||||
"""
|
||||
|
||||
@@ -85,7 +85,8 @@ import cutlass_bindings
|
||||
|
||||
from cutlass import CUTLASS_PATH, logger, swizzle
|
||||
from cutlass.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal
|
||||
from cutlass.backend.library import ApiVersion
|
||||
from cutlass.backend.conv2d_operation import Conv2dOperation
|
||||
from cutlass.backend.library import ApiVersion, ConvKindNames
|
||||
from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate
|
||||
from cutlass.emit import common
|
||||
|
||||
@@ -95,12 +96,26 @@ if torch_available:
|
||||
|
||||
|
||||
_PYTORCH_CUDA_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/util/device_memory.h"
|
||||
|
||||
// helper function allocating the memory
|
||||
void* device_memory_allocation(size_t size, int device_id=0) {
|
||||
if (size > 0) {
|
||||
torch::Device device(torch::kCUDA, device_id);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device);
|
||||
at::Tensor device_tensor = torch::empty({(long)size,}, options);
|
||||
return reinterpret_cast<void*>(device_tensor.data_ptr());
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
${includes}
|
||||
${declaration}
|
||||
${impl}
|
||||
@@ -143,6 +158,72 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
}
|
||||
"""
|
||||
|
||||
_PYTORCH_CONV2D_FPROP_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
// CUDA forward declarations
|
||||
at::Tensor ${name}_kernel(
|
||||
const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
||||
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
||||
float alpha=1.f, float beta=0.f,
|
||||
std::string split_k_mode="serial", int split_k_slices=1);
|
||||
|
||||
// C++ interface
|
||||
at::Tensor ${name}(
|
||||
const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
||||
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
||||
float alpha=1.f, float beta=0.f,
|
||||
std::string split_k_mode="serial", int split_k_slices=1) {
|
||||
return ${name}_kernel(A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("run",
|
||||
py::overload_cast<
|
||||
const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
|
||||
std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
|
||||
&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
|
||||
py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1),
|
||||
py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
|
||||
py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
|
||||
}
|
||||
"""
|
||||
|
||||
_PYTORCH_CONV2D_GRAD_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
// CUDA forward declarations
|
||||
at::Tensor ${name}_kernel(
|
||||
std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
||||
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
||||
float alpha=1.f, float beta=0.f,
|
||||
std::string split_k_mode="serial", int split_k_slices=1);
|
||||
|
||||
// C++ interface
|
||||
at::Tensor ${name}(
|
||||
std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
||||
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
||||
float alpha=1.f, float beta=0.f,
|
||||
std::string split_k_mode="serial", int split_k_slices=1) {
|
||||
return ${name}_kernel(result_size, A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("run",
|
||||
py::overload_cast<
|
||||
std::tuple<int, int, int, int>, const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
|
||||
std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
|
||||
&${name}), py::arg("result_size"), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
|
||||
py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1),
|
||||
py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
|
||||
py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
|
||||
}
|
||||
"""
|
||||
|
||||
_PYTORCH_GEMM_INCLUDES = {
|
||||
ApiVersion.v2x: """
|
||||
#include "cutlass/gemm/device/gemm_universal.h"
|
||||
@@ -162,6 +243,13 @@ _PYTORCH_GROUPED_GEMM_INCLUDES = """
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
"""
|
||||
|
||||
_PYTORCH_CONV2D_INCLUDES = """
|
||||
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d_dgrad.h"
|
||||
#include "cutlass/conv/kernel/default_conv2d_wgrad.h"
|
||||
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
||||
"""
|
||||
|
||||
_CUTLASS_TYPE_TO_TORCH_TYPE = {
|
||||
cutlass_bindings.float16: "torch::kF16",
|
||||
cutlass_bindings.float32: "torch::kF32",
|
||||
@@ -356,6 +444,133 @@ std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const s
|
||||
"""
|
||||
)
|
||||
|
||||
_PYTORCH_CONV2D_IMPL_TEMPLATE_2x = """
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
cutlass::Status status = ${name}_kernel_run(
|
||||
&problem_size,
|
||||
reinterpret_cast<typename UnderlyingKernel::ElementA*>(A.data_ptr()),
|
||||
reinterpret_cast<typename UnderlyingKernel::ElementB*>(B.data_ptr()),
|
||||
ptrC,
|
||||
reinterpret_cast<typename UnderlyingKernel::ElementC*>(D.data_ptr()),
|
||||
alpha, beta,
|
||||
split_k_mode, stream, B.device().index());
|
||||
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
|
||||
return D;
|
||||
}
|
||||
"""
|
||||
|
||||
_PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x = (
|
||||
common._CUTLASS_KERNEL_RUN_CONV2D_2x
|
||||
+ """
|
||||
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
||||
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
||||
float alpha=1.f, float beta=0.f, std::string split_k_mode="serial", int split_k_slices=1) {
|
||||
int N, H, W, C_, K, R, S, P, Q;
|
||||
N = A.size(0);
|
||||
C_ = A.size(1);
|
||||
H = A.size(2);
|
||||
W = A.size(3);
|
||||
|
||||
K = B.size(0);
|
||||
R = B.size(2);
|
||||
S = B.size(3);
|
||||
|
||||
cutlass::conv::Conv2dProblemSize problem_size(
|
||||
cutlass::Tensor4DCoord(N, H, W, C_),
|
||||
cutlass::Tensor4DCoord(K, R, S, C_),
|
||||
cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
|
||||
cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
|
||||
cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
split_k_slices
|
||||
);
|
||||
|
||||
P = problem_size.P;
|
||||
Q = problem_size.Q;
|
||||
|
||||
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
|
||||
nullptr :
|
||||
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
|
||||
|
||||
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
|
||||
at::Tensor D = torch::zeros({N, K, P, Q}, options);
|
||||
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
|
||||
)
|
||||
|
||||
|
||||
_PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x = (
|
||||
common._CUTLASS_KERNEL_RUN_CONV2D_2x
|
||||
+ """
|
||||
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> input_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
||||
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
|
||||
std::string split_k_mode="serial", int split_k_slices=1) {
|
||||
int N, H, W, C_, K, R, S;
|
||||
N = std::get<0>(input_size);
|
||||
C_ = std::get<1>(input_size);
|
||||
H = std::get<2>(input_size);
|
||||
W = std::get<3>(input_size);
|
||||
|
||||
K = B.size(0);
|
||||
R = B.size(2);
|
||||
S = B.size(3);
|
||||
|
||||
cutlass::conv::Conv2dProblemSize problem_size(
|
||||
cutlass::Tensor4DCoord(N, H, W, C_),
|
||||
cutlass::Tensor4DCoord(K, R, S, C_),
|
||||
cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
|
||||
cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
|
||||
cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
split_k_slices
|
||||
);
|
||||
|
||||
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
|
||||
nullptr :
|
||||
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
|
||||
|
||||
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
|
||||
at::Tensor D = torch::empty({N, C_, H, W}, options);
|
||||
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
|
||||
)
|
||||
|
||||
|
||||
_PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x = (
|
||||
common._CUTLASS_KERNEL_RUN_CONV2D_2x
|
||||
+ """
|
||||
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> weight_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
||||
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
|
||||
std::string split_k_mode="serial", int split_k_slices=1) {
|
||||
int N, H, W, C_, K, R, S;
|
||||
K = std::get<0>(weight_size);
|
||||
C_ = std::get<1>(weight_size);
|
||||
R = std::get<2>(weight_size);
|
||||
S = std::get<3>(weight_size);
|
||||
|
||||
N = B.size(0);
|
||||
H = B.size(2);
|
||||
W = B.size(3);
|
||||
|
||||
cutlass::conv::Conv2dProblemSize problem_size(
|
||||
cutlass::Tensor4DCoord(N, H, W, C_),
|
||||
cutlass::Tensor4DCoord(K, R, S, C_),
|
||||
cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
|
||||
cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
|
||||
cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
|
||||
cutlass::conv::Mode::kCrossCorrelation,
|
||||
split_k_slices
|
||||
);
|
||||
|
||||
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
|
||||
nullptr :
|
||||
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
|
||||
|
||||
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
|
||||
at::Tensor D = torch::empty({K, C_, R, S}, options);
|
||||
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
|
||||
)
|
||||
|
||||
|
||||
_PYTORCH_SETUP_PY = common._PYSTYLE_AUTOGEN_COMMENT + """
|
||||
from setuptools import setup
|
||||
@@ -607,6 +822,73 @@ def _pytorch_grouped_gemm(
|
||||
return None
|
||||
|
||||
|
||||
def _pytorch_conv2d(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
|
||||
"""
|
||||
Generates source for building a PyTorch CUDA module that leverages the CUTLASS Conv2d
|
||||
specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
|
||||
compiled, loaded, and returned.
|
||||
|
||||
:param op: operation to emit in the module
|
||||
:param name: name of the module to generate
|
||||
:type name: str
|
||||
:param cc: compute capability of the device the module should target
|
||||
:type cc: int
|
||||
:param jit: whether the module should be just-in-time compiled
|
||||
:type jit: bool
|
||||
:param sourcedir: directory to which generated source files should be written
|
||||
:type sourcedir: str
|
||||
|
||||
Note that the when conv kind is `dgrad` or `wgrad`, the size of the input `(N, C, H, W)` or
|
||||
weight `(K, C, R, S)` should be provided. This is because there are multiple valid solutions
|
||||
for H/W/R/S given the same P/Q.
|
||||
|
||||
:return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
|
||||
"""
|
||||
if sourcedir != "" and not os.path.isdir(sourcedir):
|
||||
os.makedirs(sourcedir)
|
||||
cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
|
||||
extra_kw = {}
|
||||
if op.conv_kind == cutlass_bindings.conv.Operator.fprop:
|
||||
impl_template = _PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x
|
||||
cpp_template = _PYTORCH_CONV2D_FPROP_CPP_TEMPLATE
|
||||
elif op.conv_kind == cutlass_bindings.conv.Operator.dgrad:
|
||||
impl_template = _PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x
|
||||
cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
|
||||
elif op.conv_kind == cutlass_bindings.conv.Operator.wgrad:
|
||||
impl_template = _PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x
|
||||
cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
|
||||
extra_kw["conv_kind_name"] = ConvKindNames[op.conv_kind].capitalize()
|
||||
extra_kw["torch_type_C"] = _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element]
|
||||
cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw})
|
||||
cuda_source = SubstituteTemplate(
|
||||
_PYTORCH_CUDA_TEMPLATE,
|
||||
{
|
||||
"includes": _PYTORCH_CONV2D_INCLUDES,
|
||||
"declaration": op.rt_module.emit(),
|
||||
"procedural_name": op.procedural_name(),
|
||||
"impl": cuda_impl,
|
||||
"torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
|
||||
},
|
||||
)
|
||||
with open(cuda_file, "w") as outfile:
|
||||
outfile.write(cuda_source)
|
||||
|
||||
cpp_file = os.path.join(sourcedir, name + ".cpp")
|
||||
cpp_source = SubstituteTemplate(
|
||||
cpp_template,
|
||||
{"name": name, "description": f"CUTLASS {op.procedural_name()} Conv2d"},
|
||||
)
|
||||
with open(cpp_file, "w") as outfile:
|
||||
outfile.write(cpp_source)
|
||||
|
||||
_generate_setup(name, sourcedir)
|
||||
|
||||
if jit:
|
||||
return _jit(name, cc, cpp_file, cuda_file)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def pytorch(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
|
||||
"""
|
||||
Generates source for building a PyTorch CUDA module that leverages the CUTLASS kernel
|
||||
@@ -633,6 +915,8 @@ def pytorch(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
|
||||
return _pytorch_gemm(device_op, name, cc, jit, sourcedir)
|
||||
elif isinstance(op, GemmOperationGrouped):
|
||||
return _pytorch_grouped_gemm(device_op, name, cc, jit, sourcedir)
|
||||
elif isinstance(op, Conv2dOperation):
|
||||
return _pytorch_conv2d(device_op, name, cc, jit, sourcedir)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Operation type {type(op)} is not currently supported for PyTorch emission."
|
||||
|
||||
@@ -43,6 +43,9 @@ _cuda_version = __version__.split("rc")[0]
|
||||
# Imports from CUTLASS profiler generator and manifest scripts
|
||||
import generator as prof_generator
|
||||
import manifest as prof_manifest
|
||||
from library import (
|
||||
ConvKind, IteratorAlgorithm, StrideSupport, GroupMode
|
||||
)
|
||||
|
||||
import cutlass
|
||||
from cutlass.utils.check import valid_stage_count
|
||||
@@ -132,6 +135,8 @@ class KernelsForDataType:
|
||||
ld = shape[0]
|
||||
elif layout == cutlass.LayoutType.RowMajor:
|
||||
ld = shape[1]
|
||||
elif layout == cutlass.LayoutType.TensorNHWC:
|
||||
ld = shape[-1]
|
||||
else:
|
||||
raise Exception(f"Unexpected or unsupported layout {layout}")
|
||||
|
||||
@@ -222,8 +227,9 @@ class ArchOptions:
|
||||
# find available opclasses and data types
|
||||
for name, op_list in manifest.operations[operation_kind].items():
|
||||
for op in op_list:
|
||||
if op.gemm_kind not in gemm_kinds:
|
||||
continue
|
||||
if operation_kind == cutlass.OperationKind.Gemm:
|
||||
if op.gemm_kind not in gemm_kinds:
|
||||
continue
|
||||
|
||||
mi = op.tile_description.math_instruction
|
||||
if mi.math_operation not in self.allowed_math_operations:
|
||||
@@ -276,21 +282,36 @@ class ArchOptions:
|
||||
if cutlass.OpcodeClass.Simt not in self.operations_by_opclass:
|
||||
self.operations_by_opclass[cutlass.OpcodeClass.Simt] = {}
|
||||
|
||||
types = [
|
||||
(cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s8),
|
||||
(cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s32),
|
||||
(cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16),
|
||||
(cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32),
|
||||
(cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32),
|
||||
(cutlass.DataType.f64, cutlass.DataType.f64, cutlass.DataType.f64),
|
||||
]
|
||||
if operation_kind == cutlass.OperationKind.Gemm:
|
||||
types = [
|
||||
(cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s8),
|
||||
(cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s32),
|
||||
(cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16),
|
||||
(cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32),
|
||||
(cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32),
|
||||
(cutlass.DataType.f64, cutlass.DataType.f64, cutlass.DataType.f64),
|
||||
]
|
||||
|
||||
layouts = [
|
||||
(cutlass.LayoutType.RowMajor, cutlass.LayoutType.RowMajor),
|
||||
(cutlass.LayoutType.RowMajor, cutlass.LayoutType.ColumnMajor),
|
||||
(cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.RowMajor),
|
||||
(cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.ColumnMajor),
|
||||
]
|
||||
elif operation_kind == cutlass.OperationKind.Conv2d:
|
||||
types = [
|
||||
(cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16),
|
||||
(cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32),
|
||||
(cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32),
|
||||
(cutlass.DataType.f64, cutlass.DataType.f64, cutlass.DataType.f64),
|
||||
]
|
||||
|
||||
layouts = [
|
||||
(cutlass.LayoutType.TensorNHWC, cutlass.LayoutType.TensorNHWC),
|
||||
]
|
||||
else:
|
||||
raise NotImplementedError(f"Operation kind {operation_kind} is currently unsupported.")
|
||||
|
||||
layouts = [
|
||||
(cutlass.LayoutType.RowMajor, cutlass.LayoutType.RowMajor),
|
||||
(cutlass.LayoutType.RowMajor, cutlass.LayoutType.ColumnMajor),
|
||||
(cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.RowMajor),
|
||||
(cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.ColumnMajor),
|
||||
]
|
||||
alignment = 1
|
||||
epilogue_functor = cutlass.EpilogueFunctor.LinearCombination
|
||||
swizzling_functor = cutlass.SwizzlingFunctor.Identity8
|
||||
@@ -319,12 +340,22 @@ class ArchOptions:
|
||||
if not valid_stage_count(target_cc, td_from_profiler_td(td))[0]:
|
||||
continue
|
||||
|
||||
new_operation = prof_manifest.GemmOperation(
|
||||
cutlass.GemmKind.Universal, td.minimum_compute_capability,
|
||||
td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor)
|
||||
|
||||
new_kernels = KernelsForDataType(type_comb, layout_comb)
|
||||
new_kernels.add(new_operation)
|
||||
|
||||
if operation_kind == cutlass.OperationKind.Gemm:
|
||||
new_operation = prof_manifest.GemmOperation(
|
||||
cutlass.GemmKind.Universal, td.minimum_compute_capability,
|
||||
td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor)
|
||||
new_kernels.add(new_operation)
|
||||
elif operation_kind == cutlass.OperationKind.Conv2d:
|
||||
for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
|
||||
new_operation = prof_manifest.Conv2dOperation(
|
||||
conv_kind, IteratorAlgorithm.Analytic, td.minimum_compute_capability, td,
|
||||
A, B, C, type_comb[2], StrideSupport.Strided, epilogue_functor, swizzling_functor,
|
||||
group_mode=GroupMode.SingleGroup
|
||||
)
|
||||
new_kernels.add(new_operation)
|
||||
|
||||
self.operations_by_opclass[cutlass.OpcodeClass.Simt][comb] = new_kernels
|
||||
|
||||
# Sort all operations
|
||||
@@ -437,9 +468,12 @@ class OptionRegistry:
|
||||
self.registry = {}
|
||||
|
||||
gemm_kinds = [cutlass.GemmKind.Universal, cutlass.GemmKind.Universal3x]
|
||||
operation_kinds = [cutlass.OperationKind.Gemm, cutlass.OperationKind.Conv2d]
|
||||
# Construct options for each CC
|
||||
for kernel_cc in _generator_ccs:
|
||||
self.registry[kernel_cc] = ArchOptions(target_cc, kernel_cc, cutlass.OperationKind.Gemm, gemm_kinds)
|
||||
self.registry[kernel_cc] = {}
|
||||
for opkind in operation_kinds:
|
||||
self.registry[kernel_cc][opkind] = ArchOptions(target_cc, kernel_cc, opkind, gemm_kinds)
|
||||
|
||||
def options_for_cc(self, cc: int) -> ArchOptions:
|
||||
return self.registry.get(cc, None)
|
||||
def options_for_cc(self, cc: int, op_kind=cutlass.OperationKind.Gemm) -> ArchOptions:
|
||||
return self.registry.get(cc, None)[op_kind]
|
||||
|
||||
@@ -31,5 +31,6 @@
|
||||
#################################################################################################
|
||||
|
||||
from cutlass.op.gemm import Gemm
|
||||
from cutlass.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
|
||||
from cutlass.op.gemm_grouped import GroupedGemm
|
||||
from cutlass.op.op import OperationBase
|
||||
|
||||
960
python/cutlass/op/conv.py
Normal file
960
python/cutlass/op/conv.py
Normal file
@@ -0,0 +1,960 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Ease-of-use interface for constructing, compiling, and running CONVs
|
||||
|
||||
The ``Conv2d`` interface is meant to allow one to easily instantiate, compile, and run
|
||||
CONV2D operations in CUTLASS via Python, without specifying many configuration parameters.
|
||||
Under the hood, the interface will select sensible default parameters for the many template
|
||||
parameters for CUTLASS CONVs.
|
||||
|
||||
Note: optimal performance is not to be expected from this interface. To achieve optimal
|
||||
performance, one should specify and tune each configuration parameter.
|
||||
|
||||
The simplest example of using this interface is the following:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
# A, B, C, and D are torch/numpy/cupy tensor objects
|
||||
plan = cutlass.op.Conv2d(A, B, C, D)
|
||||
plan.run(stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
||||
|
||||
One can also use the interface by specifying data types of operands at construction
|
||||
and using different tensor objects with these data types at runtime:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
# The following is shorthand for:
|
||||
# cutlass.op.Conv2d(kind="fprop",
|
||||
# element_A=torch.float32, element_B=torch.float32,
|
||||
# element_C=torch.float32, element_D=torch.float32,
|
||||
# element_accumulator=torch.float32)
|
||||
plan = cutlass.op.Conv2d(kind="fprop", element=torch.float32)
|
||||
|
||||
A0 = torch.rand((128, 256), dtype=torch.float32, device='cuda')
|
||||
B0 = torch.rand((256, 64), dtype=torch.float32, device='cuda')
|
||||
C0 = torch.zeros((128, 64), dtype=torch.float32, device='cuda')
|
||||
D0 = torch.zeros((128, 64), dtype=torch.float32, device='cuda')
|
||||
plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
||||
|
||||
A = torch.rand((32, 128), dtype=torch.float32, device='cuda')
|
||||
B = torch.rand((128, 256), dtype=torch.float32, device='cuda')
|
||||
C = torch.zeros((32, 256), dtype=torch.float32, device='cuda')
|
||||
D = torch.zeros((32, 256), dtype=torch.float32, device='cuda')
|
||||
plan.run(A, B, C, D, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
||||
|
||||
The interface additionally enables one to decouple the compilation of the underlying CUTLASS
|
||||
kernel from its execution:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass.op.Conv2d(kind="fprop", element=np.float32)
|
||||
|
||||
# Do other work...
|
||||
|
||||
plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
||||
|
||||
# Do other work...
|
||||
|
||||
plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
||||
|
||||
Elementwise activation functions are easily fused to the convolution via the interface:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass.op.Conv2d(kind="fprop", element=np.float32)
|
||||
plan.activation = cutlass.epilogue.relu
|
||||
|
||||
Operations can also be run asynchronously:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
plan = cutlass.op.Conv2d(kind="fprop", element=np.float32)
|
||||
args = plan.run()
|
||||
|
||||
# Do other work...
|
||||
|
||||
args.sync()
|
||||
"""
|
||||
|
||||
import cutlass_bindings
|
||||
import cutlass
|
||||
from cutlass import epilogue
|
||||
from cutlass.backend import compiler
|
||||
from cutlass.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
|
||||
from cutlass.backend.reduction_operation import ReductionOperation, ReductionArguments
|
||||
from cutlass.backend.library import TensorDescription, TileDescription
|
||||
from cutlass.op.op import OperationBase
|
||||
from cutlass.utils import check, datatypes
|
||||
|
||||
class Conv2d(OperationBase):
|
||||
"""
|
||||
Constructs a ``Conv2d`` object.
|
||||
|
||||
The convolution kind (fprop, wgrad, degrad), the data types of operands A, B, and C,
|
||||
along with the data type of output D and that used for accumulation, are bound to the ``Conv``
|
||||
object throughout its lifetime -- these are not to be changed after a ``Conv2d`` has been constructed.
|
||||
|
||||
The constructor has optional parameters for flexibly setting these parameters. The following
|
||||
constructors are equivalent:
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
# Use F32 for A, B, C, D, and accumulation in fprop
|
||||
|
||||
# Use the generic ``element`` parameter to concisely set all data types for operands to the same values.
|
||||
Conv2d(kind="fprop", element=cutlass.DataType.f32)
|
||||
|
||||
# Explicitly specify the data types to use for A, B, C, and D.
|
||||
Conv2d(kind="fprop", element_A=cutlass.DataType.f32, element_B=cutlass.DataType.f32,
|
||||
element_C=cutlass.DataType.f32, element_D=cutlass.DataType.f32)
|
||||
|
||||
# Set the data types and elements from existing tensors. Note that one can use different tensors when
|
||||
# executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
|
||||
# have the same data type as those passed in here).
|
||||
# A, B, C, and D are torch.Tensor objects of type torch.float32 under the channel-last layout
|
||||
Conv2d(kind="fprop", A=A, B=B, C=C, D=D)
|
||||
|
||||
# Explicitly specify the data type for only some of A, B, C, and D. Unspecified data types will inherit
|
||||
# those passed in via the generic ``element``
|
||||
Conv2d(kind="fprop", element_A=cutlass.DataType.f32, element_accumulator=cutlass.DataType.f32,
|
||||
element=cutlass.DataType.f32)
|
||||
|
||||
The order of precedence for the setting of the data type for a given operand/output is as follows:
|
||||
1) If the tensor type is specified (e.g., ``A``), use the data type inferred from this tensor
|
||||
2) Otherwise, if the data type (e.g., ``element_A``) is specified, use those
|
||||
3) Otherwise, use the generic values (e.g., ``element``)
|
||||
|
||||
:param kind: the convolution kind (i.e. fprop, wgrad, and dgrad)
|
||||
:type kind: str
|
||||
:param A: tensor representing data type of operand A
|
||||
:param B: tensor representing data type of operand B
|
||||
:param C: tensor representing data type of operand C
|
||||
:param D: tensor representing data type of operand D
|
||||
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
|
||||
:param beta: scalar parameter beta from GEMM operation that scales operand C
|
||||
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
|
||||
:type element: cutlass.DataType
|
||||
:param element_A: data type to be used for operand A
|
||||
:type element_A: cutlass.DataType
|
||||
:param element_B: data type to be used for operand B
|
||||
:type element_B: cutlass.DataType
|
||||
:param element_C: data type to be used for operand C
|
||||
:type element_C: cutlass.DataType
|
||||
:param element_D: data type to be used for operand D
|
||||
:type element_D: cutlass.DataType
|
||||
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
|
||||
:type element_accumulator: cutlass.DataType
|
||||
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
|
||||
:type cc: int
|
||||
:param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
|
||||
:type kernel_cc: int
|
||||
"""
|
||||
def __init__(
    self, kind="fprop",
    A=None, B=None, C=None, D=None, alpha=1.0, beta=0.0,
    element=None,
    element_A=None, element_B=None, element_C=None, element_D=None,
    element_accumulator=None,
    cc: int = None, kernel_cc: int = None
):
    """
    Constructs a ``Conv2d`` operation plan.

    Resolves the data type of each operand (precedence: tensor argument,
    then ``element_X``, then the generic ``element``), binds the convolution
    kind, and selects default kernel options from the CUTLASS manifest.

    :param kind: convolution kind ("fprop", "dgrad", or "wgrad")
    :type kind: str
    :param A: tensor representing data type of operand A
    :param B: tensor representing data type of operand B
    :param C: tensor representing data type of operand C
    :param D: tensor representing data type of operand D
    :param alpha: scalar that scales the product of operands A and B
    :param beta: scalar that scales operand C
    :param element: generic data type used for any operand not otherwise specified
    :param element_A: data type of operand A
    :param element_B: data type of operand B
    :param element_C: data type of operand C
    :param element_D: data type of operand D
    :param element_accumulator: accumulation data type (defaults to that of C)
    :param cc: compute capability of the device on which kernels will run
    :type cc: int
    :param kernel_cc: compute capability for which kernels should be generated
    :type kernel_cc: int

    :raises Exception: if both element_X and tensor X are given for an operand,
        or if neither is given and no generic ``element`` is provided
    """
    super().__init__(cc=cc, kernel_cc=kernel_cc, operation_kind=cutlass.OperationKind.Conv2d)
    # Verify the kernel cc
    if self.current_cc == 90:
        # The Conv2d kernel on Hopper (SM90) is currently unsupported.
        # Revert to use SM80-tagged kernels.
        cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
        self.specified_kernel_cc = 80
        self._reset_options(80)

    # The arch is used in testing
    self.arch = self.current_cc
    self.name = "conv2d" + kind

    # The convolution kind. (concept: cutlass_bindings.conv.Operator)
    self.conv_kind = getattr(cutlass_bindings.conv.Operator, kind)

    # The element types (concept: cutlass library types) of A, B, C, and D
    elements = []
    layouts = []

    # Complete the data types based on user-provided arguments
    for elt, tens, name in zip([element_A, element_B, element_C, element_D],
                               [A, B, C, D],
                               ["A", "B", "C", "D"]):
        if elt is not None and tens is not None:
            raise Exception(f'Must not specify both element_{name} and tensor {name}')
        if elt is None and tens is None and element is None:
            raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')

        elt_to_set = None
        lay_to_set = None

        if tens is not None:
            elt_to_set, _ = datatypes.get_datatype_and_layout(tens)
        else:
            elt_to_set = elt if elt is not None else element

        assert elt_to_set is not None

        # Currently we only support layout TensorNHWC
        lay_to_set = cutlass.LayoutType.TensorNHWC
        elements.append(datatypes.library_type(elt_to_set))
        layouts.append(lay_to_set)

    self._element_a, self._element_b, self._element_c, self._element_d = elements
    self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts

    # Default inputs and scalars if none are supplied in run()
    # (the original assigned A/B/C/D/alpha/beta twice; one assignment suffices)
    self.A, self.B, self.C, self.D, self.alpha, self.beta = A, B, C, D, alpha, beta

    if element_accumulator is None:
        self._element_accumulator = self._element_c
    else:
        self._element_accumulator = datatypes.library_type(element_accumulator)

    # We only specify the stride of the swizzling functor here.
    # The actual swizzling functor is determined in run based on conv_kind and stride.
    self._swizzling_stride = 1

    # Arguments that will be set to default value in _reset_operations.
    # The default tile_description and op_class are fetched from the manifest
    # of the cutlass library.
    self._tile_description = None
    self.op_class = None
    # The default identity epilogue will be created
    self.epilogue_functor = None

    self._reset_operations()

    # Arguments that will be determined online based on arguments of "run",
    # based on stride, input/output channels, alignment, and conv_kind
    self._iterator_algorithm = None
    self._stride_support = None
|
||||
|
||||
def _reset_operations(self, reset_epilogue: bool = True):
    """
    Refreshes the candidate operation set for the currently-bound data types
    and layouts: selects an opcode class (preferring TensorOp over Simt),
    optionally resets the epilogue to the identity activation, and computes
    preferred operand alignments.

    :param reset_epilogue: whether to reset the epilogue functor to the identity activation
    :type reset_epilogue: bool

    :raises Exception: if no kernel exists for the data type / layout combination
    """
    # Set the default op class
    datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
    layout_comb = (self._layout_a, self._layout_b)

    # Opcode classes for which the manifest has kernels matching these types/layouts
    self.possible_op_classes = self.options.supporting_opclasses(
        self._element_a, self._element_b, self._element_accumulator,
        self._layout_a, self._layout_b
    )

    # Prefer tensor cores when available; fall back to SIMT
    if cutlass.OpcodeClass.TensorOp in self.possible_op_classes:
        self.opclass = cutlass.OpcodeClass.TensorOp
    elif cutlass.OpcodeClass.Simt in self.possible_op_classes:
        self.opclass = cutlass.OpcodeClass.Simt
    else:
        raise Exception(f'No kernel configuration found for supported data type and layout '
                        f'combination {datatype_comb}x{layout_comb}')

    if reset_epilogue:
        self._reset_epilogue_functor_activation(epilogue.identity)

    # Preferred alignment per operand: capped at a 128-bit access,
    # and at the maximum alignment any candidate kernel supports
    self.alignment_pref_A = min(
        128 // cutlass.DataTypeSize[self._element_a], max(self.possible_operations.alignments))
    self.alignment_pref_B = min(
        128 // cutlass.DataTypeSize[self._element_b], max(self.possible_operations.alignments))
    self.alignment_pref_C = min(
        128 // cutlass.DataTypeSize[self._element_c], max(self.possible_operations.alignments))
|
||||
|
||||
#
|
||||
# Tile description Related
|
||||
#
|
||||
|
||||
@property
def tile_description(self) -> TileDescription:
    """
    Returns the tile description currently bound to this plan (None until one
    has been set explicitly or chosen during construction).

    :rtype: cutlass.backend.TileDescription
    """
    return self._tile_description
|
||||
|
||||
@tile_description.setter
def tile_description(
    self, td=None):
    """
    Set the tile description

    :param td: tile description
    :type td: cutlass.backend.TileDescription, or a dict with keys
        {
            "threadblock_shape": [int, int, int],
            "warp_count": [int, int, int],
            "stages": int,
            "instruction_shape": [int, int, int] (optional),
            "cluster_shape": [int, int, int] (optional)
        }

    :raises Exception: if the resulting tile description is invalid for the target CC
    """
    # Setting None is a no-op rather than an error
    if td is None:
        return
    if isinstance(td, dict):
        # A dict only overrides fields of an existing description; seed one
        # from the first candidate kernel if none has been set yet
        if self._tile_description is None:
            alignment = list(self.possible_operations.kernels_by_alignment.keys())[0]
            op = self.possible_operations.operations(alignment)[0]
            self._tile_description = datatypes.td_from_profiler_op(op)
        if "cluster_shape" in td.keys():
            # Conv2d kernels do not support thread block clusters; coerce to unit shape
            if td["cluster_shape"] != [1, 1, 1]:
                cutlass.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.")
                td["cluster_shape"] = [1, 1, 1]
        td = self._tile_description.clone_and_update(td)

    # Validate against the compute capability before accepting
    valid, msg = self._valid_tile_description(td)
    if valid:
        self._tile_description = td
    else:
        raise Exception(msg)
|
||||
|
||||
def _valid_tile_description(self, td: TileDescription) -> tuple:
    """
    Checks whether the provided tile description is valid for the given compute
    capability. At present, this verifies that the requested stage count fits
    in shared memory for the compilation target and that the cluster shape is
    legal for the architecture kernels are drawn from.

    :param td: tile description to validate
    :type td: cutlass.backend.TileDescription
    :return: tuple of (bool validity flag, optional error message string)
    :rtype: tuple
    """
    # Stage count is checked against the CC we compile for (self.cc); the
    # cluster shape is checked against the CC we select kernels from
    # (self.current_cc).
    validators = (
        lambda: check.valid_stage_count(self.cc, td),
        lambda: check.valid_cluster_shape(self.current_cc, td.cluster_shape),
    )
    ok, msg = True, ''
    for validate in validators:
        ok, msg = validate()
        if not ok:
            break
    return (ok, msg)
|
||||
|
||||
def tile_descriptions(self) -> list:
    """
    Returns a list of valid tile descriptions for the operations, with
    duplicates (as determined by their string form) removed.

    :returns: list of valid tile descriptions for the operations
    :rtype: list
    """
    seen = set()
    unique_tds = []
    for operation in self.possible_operations.all_operations:
        candidate = datatypes.td_from_profiler_op(operation)
        key = str(candidate)
        if key in seen:
            continue
        seen.add(key)
        unique_tds.append(candidate)
    return unique_tds
|
||||
|
||||
#
|
||||
# Swizzling functor Related
|
||||
#
|
||||
|
||||
@property
def swizzling_stride(self) -> int:
    """
    Returns the swizzling stride currently being used by the Conv2d plan.

    :return: swizzling stride
    :rtype: int
    """
    return self._swizzling_stride
|
||||
|
||||
@swizzling_stride.setter
def swizzling_stride(self, stride: int):
    """
    Sets the stride used by the swizzling functor.

    :param stride: swizzling stride; must be one of 1, 2, 4, or 8, matching
        the ``IdentitySwizzle{N}`` functors available in ``cutlass.swizzle``
    :type stride: int

    :raises Exception: if ``stride`` is not one of the supported integers
    """
    # Fail fast on unsupported values: the original accepted any int and an
    # invalid one only surfaced later as an AttributeError when
    # _propose_swizzling_functor looked up the (nonexistent) functor name.
    if not isinstance(stride, int) or stride not in (1, 2, 4, 8):
        raise Exception(f"Expect integer (1, 2, 4, 8), got {stride}")
    self._swizzling_stride = stride
|
||||
|
||||
def _propose_swizzling_functor(self, stride):
    """
    Automatically proposes a swizzling functor based on the convolution kind
    and the (vertical, horizontal) stride: dgrad with a non-unit stride uses
    the strided-dgrad identity swizzle; everything else uses the plain
    identity swizzle at the configured swizzling stride.
    """
    functor_name = f"IdentitySwizzle{self._swizzling_stride}"
    is_dgrad = self.conv_kind == cutlass_bindings.conv.Operator.dgrad
    if is_dgrad and (stride[0], stride[1]) != (1, 1):
        functor_name = f"StridedDgradIdentitySwizzle{self._swizzling_stride}"
    return getattr(cutlass.swizzle, functor_name)
|
||||
|
||||
#
|
||||
# Iterator Algorithm Related
|
||||
#
|
||||
|
||||
@property
def iterator_algorithm(self) -> cutlass_bindings.conv.IteratorAlgorithm:
    """
    Returns the iterator algorithm (None until one is set explicitly;
    a default is chosen during construction otherwise).
    """
    return self._iterator_algorithm
|
||||
|
||||
@iterator_algorithm.setter
def iterator_algorithm(self, alg: str):
    """
    Sets the iterator algorithm

    :param alg: The iterator algorithm
    :type alg: string, options: "analytic", "optimized", "few_channels", and "fixed_channels"

    :raises Exception: if the algorithm is incompatible with the convolution kind
    :raises AttributeError: if ``alg`` is not a valid IteratorAlgorithm name
        (surfaced by the getattr lookup below)
    """
    # Check if the iterator algorithm is valid:
    # few_channels / fixed_channels only exist for fprop
    if alg in ["few_channels", "fixed_channels"] and self.conv_kind != cutlass_bindings.conv.Operator.fprop:
        raise Exception(f"{self.conv_kind} does not support iterator algorithm {alg}.")

    self._iterator_algorithm = getattr(cutlass_bindings.conv.IteratorAlgorithm, alg)
|
||||
|
||||
def _propose_iterator_algorithm(self, problem_size, alignment_a, alignment_b) -> cutlass_bindings.conv.IteratorAlgorithm:
    """
    Propose a valid iterator algorithm based on problem size and alignment.

    Prefers the fastest applicable algorithm per convolution kind and falls
    back to the always-correct analytic iterator when the divisibility /
    filter-size constraints of the optimized iterators are not met.

    NOTE(review): if conv_kind is none of fprop/dgrad/wgrad this implicitly
    returns None — callers appear to guarantee one of the three kinds.
    """
    if self.conv_kind == cutlass_bindings.conv.Operator.fprop:
        # Check whether the fixed channel is applicable
        if problem_size.C == alignment_a:
            return cutlass_bindings.conv.IteratorAlgorithm.fixed_channels
        elif (problem_size.C % alignment_a == 0 and
            problem_size.R <= 32 and problem_size.S <= 32):
            # Optimized iterator requires aligned channels and filter extents <= 32
            return cutlass_bindings.conv.IteratorAlgorithm.optimized
        else:
            return cutlass_bindings.conv.IteratorAlgorithm.analytic
    elif self.conv_kind == cutlass_bindings.conv.Operator.dgrad:
        if (problem_size.K % alignment_a == 0 and
            problem_size.R <= 32 and problem_size.S <= 32 and
            problem_size.C % alignment_b == 0):
            return cutlass_bindings.conv.IteratorAlgorithm.optimized
        else:
            return cutlass_bindings.conv.IteratorAlgorithm.analytic
    elif self.conv_kind == cutlass_bindings.conv.Operator.wgrad:
        if (problem_size.K % alignment_a == 0 and
            problem_size.C % alignment_b == 0):
            return cutlass_bindings.conv.IteratorAlgorithm.optimized
        else:
            return cutlass_bindings.conv.IteratorAlgorithm.analytic
|
||||
|
||||
def _validate_iterator_algorithm(self, iterator_algorithm, problem_size, alignment_a, alignment_b) -> bool:
    """
    Validate whether the user-provided iterator algorithm works for the given
    problem size and operand alignments.

    Mirrors the constraints used in _propose_iterator_algorithm; any
    combination not explicitly constrained below (e.g. the analytic iterator)
    is accepted via the final ``return True``.

    :return: True if the algorithm is valid for this problem, False otherwise
    :rtype: bool
    """
    if self.conv_kind == cutlass_bindings.conv.Operator.fprop:
        if iterator_algorithm == cutlass_bindings.conv.IteratorAlgorithm.fixed_channels:
            # fixed_channels requires the channel count to equal the alignment
            return problem_size.C == alignment_a
        elif iterator_algorithm == cutlass_bindings.conv.IteratorAlgorithm.optimized:
            return (problem_size.C % alignment_a == 0 and
                problem_size.R <= 32 and problem_size.S <= 32)
        elif iterator_algorithm == cutlass_bindings.conv.IteratorAlgorithm.few_channels:
            return problem_size.C % alignment_a == 0
    elif self.conv_kind == cutlass_bindings.conv.Operator.dgrad:
        if iterator_algorithm == cutlass_bindings.conv.IteratorAlgorithm.optimized:
            return (problem_size.K % alignment_a == 0 and
                problem_size.R <= 32 and problem_size.S <= 32 and
                problem_size.C % alignment_b == 0)
    elif self.conv_kind == cutlass_bindings.conv.Operator.wgrad:
        if iterator_algorithm == cutlass_bindings.conv.IteratorAlgorithm.optimized:
            return (problem_size.K % alignment_a == 0 and
                problem_size.C % alignment_b == 0)

    # Unconstrained combinations (e.g. analytic) are always valid
    return True
|
||||
|
||||
#
|
||||
# Stride Support Related
|
||||
#
|
||||
|
||||
def _propose_stride_support(self, stride):
    """
    Proposes the stride support for the kernel: Unity for unit-stride dgrad,
    Strided for everything else.
    """
    has_unit_stride = stride[0] == 1 and stride[1] == 1
    if has_unit_stride and self.conv_kind == cutlass_bindings.conv.Operator.dgrad:
        return cutlass.backend.library.StrideSupport.Unity
    return cutlass.backend.library.StrideSupport.Strided
|
||||
|
||||
#
|
||||
# Construct and Compilation
|
||||
#
|
||||
|
||||
def construct(
    self, tile_description: TileDescription = None,
    alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
    iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm = None,
    stride_support = None, swizzling_functor: cutlass.swizzle = None,
    epilogue_functor=None) -> cutlass.backend.Conv2dOperation:
    """
    Constructs a ``cutlass.backend.Conv2dOperation`` based on the input parameters and current
    kernel specification of the ``Conv2d`` object.

    :param tile_description: tile description specifying shapes and operand types to use in the kernel
    :type tile_description: cutlass.backend.TileDescription
    :param alignment_A: alignment of operand A
    :type alignment_A: int
    :param alignment_B: alignment of operand B
    :type alignment_B: int
    :param alignment_C: alignment of operand C
    :type alignment_C: int
    :param iterator_algorithm: the iterator algorithm used
    :type iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm
    :param stride_support: the stride support of dgrad
    :type stride_support: cutlass.backend.library.StrideSupport
    :param swizzling_functor: the swizzling functor
    :type swizzling_functor: cutlass.swizzle
    :param epilogue_functor: the epilogue functor

    :return: operation that was constructed
    :rtype: cutlass.backend.Conv2dOperation
    """
    # Get alignment, substituting the plan's preferred alignment where unspecified
    alignment_A = check.alignment_or_default(alignment_A, self.alignment_pref_A)
    alignment_B = check.alignment_or_default(alignment_B, self.alignment_pref_B)
    alignment_C = check.alignment_or_default(alignment_C, self.alignment_pref_C)

    tensor_A = TensorDescription(
        datatypes.binding_type(self._element_a),
        # Bug fix: the original passed self._layout_b here (copy-paste error);
        # operand A must use its own layout
        datatypes.binding_layout(self._layout_a),
        alignment_A
    )
    tensor_B = TensorDescription(
        datatypes.binding_type(self._element_b),
        datatypes.binding_layout(self._layout_b),
        alignment_B
    )
    tensor_C = TensorDescription(
        datatypes.binding_type(self._element_c),
        datatypes.binding_layout(self._layout_c),
        alignment_C
    )

    if tile_description is None:
        # Fall back to the plan's bound tile description, then to the first
        # candidate kernel for the chosen alignment
        if self.tile_description is not None:
            tile_description = self.tile_description
        else:
            op = self.possible_operations.operations(alignment_A)[0]
            tile_description = datatypes.td_from_profiler_op(op)
    else:
        valid, err_str = self._valid_tile_description(tile_description)
        if not valid:
            raise Exception(f"Invalid tile description. {err_str}")
        self.tile_description = tile_description

    if iterator_algorithm is None:
        # If the iterator algorithm is already set
        if self.iterator_algorithm is not None:
            iterator_algorithm = self.iterator_algorithm
        else:
            # Otherwise, we conservatively use the analytic iterator for correctness
            iterator_algorithm = cutlass_bindings.conv.IteratorAlgorithm.analytic

    if stride_support is None:
        # If the stride support is already set
        if self._stride_support is not None:
            stride_support = self._stride_support
        else:
            # Otherwise, we assume strided
            stride_support = cutlass.backend.library.StrideSupport.Strided

    if swizzling_functor is None:
        # Conservatively propose as if the conv stride were non-unit, since the
        # actual stride is unknown until run()
        swizzling_functor = self._propose_swizzling_functor(stride=(2, 2))

    if epilogue_functor is None:
        if self.epilogue_functor is not None:
            epilogue_functor = self.epilogue_functor
        else:
            epilogue_functor = self._create_epilogue_functor_activation(self._activation)

    # Reset the alignment of the epilogue functor to match operand C
    epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, epilogue_functor)

    operation = Conv2dOperation(
        conv_kind=self.conv_kind,
        iterator_algorithm=iterator_algorithm,
        arch=self.current_cc,
        tile_description=tile_description,
        A=tensor_A, B=tensor_B, C=tensor_C,
        stride_support=stride_support,
        epilogue_functor=epilogue_functor,
        swizzling_functor=swizzling_functor,
    )

    return operation
|
||||
|
||||
def compile(self, tile_description: TileDescription = None,
            alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
            iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm = None,
            stride_support = None, swizzling_functor: cutlass.swizzle = None,
            epilogue_functor = None, print_module: bool = False) -> cutlass.backend.Conv2dOperation:
    """
    Emits and compiles the kernel currently specified. If ``tile_description`` and any
    of the ``alignment`` parameters are set, the kernel will be chosen using this
    tile description and alignments. Otherwise, a default tile description and alignment
    will be used.

    :param tile_description: tile description specifying shapes and operand types to use in the kernel
    :type tile_description: cutlass.backend.TileDescription
    :param alignment_A: alignment of operand A
    :type alignment_A: int
    :param alignment_B: alignment of operand B
    :type alignment_B: int
    :param alignment_C: alignment of operand C
    :type alignment_C: int
    :param iterator_algorithm: the iterator algorithm used
    :type iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm
    :param stride_support: the stride support of dgrad
    :type stride_support: cutlass.backend.library.StrideSupport
    :param swizzling_functor: the swizzling functor
    :type swizzling_functor: cutlass.swizzle
    :param epilogue_functor: the epilogue functor
    :param print_module: whether to print the emitted C++ code before compilation
    :type print_module: bool

    :return: operation that was compiled
    :rtype: cutlass.backend.Conv2dOperation
    """
    # Build (or select defaults for) the operation, then hand it to the JIT compiler.
    # The constructed operation is cached on ``self`` so subsequent runs can reuse it.
    self.operation = self.construct(
        tile_description, alignment_A, alignment_B, alignment_C,
        iterator_algorithm, stride_support, swizzling_functor, epilogue_functor)

    if print_module:
        print(self.operation.rt_module.emit())

    compiler.add_module([self.operation,])
    return self.operation
|
||||
|
||||
#
|
||||
# Run Related
|
||||
#
|
||||
|
||||
def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
    """
    Verifies that ``tensor`` has data type ``ref_type``. An exception is raised if it does not.

    NOTE(review): despite the name, only the data type is validated here -- the layout
    returned by ``get_datatype_and_layout`` is discarded and ``ref_layout`` is accepted
    only for interface symmetry with the GEMM-side check.

    :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
    :type tensor: numpy/cupy/torch array/tensor object
    :param ref_type: data type for the tensor that this object was initialized to
    :param ref_layout: layout for the tensor that this object was initialized to (currently unused)
    :param name: identifier of the tensor to verify. Used in raising exceptions
    :type name: str
    """
    dtype, _ = datatypes.get_datatype_and_layout(tensor)
    if dtype != ref_type:
        # Only the data type is compared above, so the message must not claim a
        # layout mismatch (the previous wording said "type and layout").
        raise Exception(f'Tensor {name} with type {dtype} '
                        f'does not match the expected type of {ref_type}.')
|
||||
|
||||
|
||||
|
||||
def _get_and_verify_conv_problem_size(self, A, B, C, stride, padding, dilation):
    """
    Infers the Conv2d problem size from operands A/B/C and verifies that the
    user-supplied output tensor's spatial extent matches the extent implied by
    stride, padding, and dilation.

    Which operand plays the role of input/weight/output depends on the convolution
    kind: for fprop the output is C; for dgrad and wgrad the incoming gradient
    (the "output" of the forward pass) is operand A.

    :param A: tensor operand A
    :param B: tensor operand B
    :param C: tensor operand C
    :param stride: (stride_h, stride_w)
    :param padding: (pad_h, pad_w)
    :param dilation: (dilation_h, dilation_w)

    :return: the inferred problem size
    :rtype: cutlass_bindings.conv.Conv2dProblemSize

    :raises Exception: if the convolution kind is unsupported or the output shape
        is inconsistent with the computed problem size
    """
    if self.conv_kind == cutlass_bindings.conv.Operator.fprop:
        input = A
        weight = B
        output = C
        output_tensor = "C"
    elif self.conv_kind == cutlass_bindings.conv.Operator.dgrad:
        output = A
        weight = B
        input = C
        output_tensor = "A"
    elif self.conv_kind == cutlass_bindings.conv.Operator.wgrad:
        output = A
        input = B
        weight = C
        output_tensor = "A"
    else:
        raise Exception(f"Convolution kind {self.conv_kind} is not supported")

    # NHWC activation, KRSC filter, NPQK output (shapes as returned by get_tensor_shape)
    N_, H_, W_, C_ = datatypes.get_tensor_shape(input)
    K_, R_, S_, _ = datatypes.get_tensor_shape(weight)
    _, P_, Q_, _ = datatypes.get_tensor_shape(output)

    problem_size = cutlass_bindings.conv.Conv2dProblemSize(
        cutlass_bindings.Tensor4DCoord(N_, H_, W_, C_),
        cutlass_bindings.Tensor4DCoord(K_, R_, S_, C_),
        # Symmetric padding: (pad_h, pad_h, pad_w, pad_w)
        cutlass_bindings.Tensor4DCoord(padding[0], padding[0], padding[1], padding[1]),
        cutlass_bindings.MatrixCoord(stride[0], stride[1]),
        cutlass_bindings.MatrixCoord(dilation[0], dilation[1]),
        cutlass_bindings.conv.Mode.cross_correlation,
        1, 1
    )

    # The P/Q computed by Conv2dProblemSize from (H, W, R, S, stride, padding,
    # dilation) must agree with the spatial extent of the supplied output tensor.
    if P_ != problem_size.P or Q_ != problem_size.Q:
        raise Exception(
            f"Tensor {output_tensor} size should be ({N_}, {problem_size.P}, {problem_size.Q}, {K_}), got ({N_}, {P_}, {Q_}, {K_})")

    return problem_size
|
||||
|
||||
def run(self, A=None, B=None, C=None, D=None,
        stride=(1, 1), padding=(0, 0), dilation=(1, 1),
        alpha=None, beta=None,
        split_k=("serial", 1), sync: bool = True,
        print_module: bool = False) -> Conv2dArguments:
    """
    Runs the kernel currently specified. If it has not already been, the kernel is emitted and
    compiled. Tensors holding operands and outputs of the kernel are sourced either from the
    ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``
    parameters provided in the call, or from those
    passed in on the construction of this object -- one of the two must be specified.

    By default, this call returns only once the kernel has completed. To launch the kernel
    and immediately return, set ``sync=False``. In this case, it is the responsibility of the
    caller to synchronize the results of the kernel before attempting to access outputs
    by calling ``sync()`` on the arguments returned from this call.

    :param A: tensor representing data type and layout of operand A
    :param B: tensor representing data type and layout of operand B
    :param C: tensor representing data type and layout of operand C
    :param D: tensor representing data type and layout of operand D
    :param stride: (stride_h, stride_w) describing the convolution stride. Default: (1, 1)
    :param padding: (pad_h, pad_w) describing the convolution padding. Default: (0, 0)
    :param dilation: (dilation_h, dilation_w) describing the dilation of convolution. Default: (1, 1)
    :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
    :param beta: scalar parameter beta from GEMM operation that scales operand C
    :param split_k: a tuple (split_k_mode, split_k_slices)
    :param sync: whether the call should wait for the kernel to complete before returning
    :type sync: bool
    :param print_module: whether to print the emitted C++ code
    :type print_module: bool

    :return: arguments passed in to the kernel
    :rtype: cutlass.backend.Conv2dArguments
    """
    # Resolve operands and scalars from call-site values or those captured at construction
    A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
    B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
    C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
    D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
    alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
    beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")

    # Handle the case when there is no C: with beta == 0 the source tensor is never
    # read, so D can safely stand in for it.
    if C is None:
        if beta != 0:
            raise Exception(f"With beta {beta} != 0, C has to be provided.")
        else:
            C = D

    # Construct problem size based on input.
    # It also verifies whether the A, B, C, D, stride, padding, and dilation are matching.
    problem_size = self._get_and_verify_conv_problem_size(A, B, C, stride, padding, dilation)

    # Propose stride support based on input
    stride_support = self._propose_stride_support(stride)

    # Propose swizzling functor
    swizzling_functor = self._propose_swizzling_functor(stride)

    # Get the alignment from the tensor shapes, then clamp to any user preference
    alignment_a = self.possible_operations.find_alignment(datatypes.get_tensor_shape(A), self._layout_a)
    alignment_b = self.possible_operations.find_alignment(datatypes.get_tensor_shape(B), self._layout_b)
    alignment_c = self.possible_operations.find_alignment(datatypes.get_tensor_shape(C), self._layout_c)

    alignment_a = check.update_alignment(alignment_a, self.alignment_pref_A)
    alignment_b = check.update_alignment(alignment_b, self.alignment_pref_B)
    alignment_c = check.update_alignment(alignment_c, self.alignment_pref_C)

    # Propose iterator algorithm based on input
    if self._iterator_algorithm is None:
        # Propose a default iterator algorithm based on the problem size
        iterator_algorithm = self._propose_iterator_algorithm(problem_size, alignment_a, alignment_b)
    else:
        if (self._validate_iterator_algorithm(self._iterator_algorithm, problem_size, alignment_a, alignment_b)):
            iterator_algorithm = self._iterator_algorithm
        else:
            raise Exception(f"Iterator algorithm {self._iterator_algorithm} is invalid for current problem.")

    epilogue_args = [alpha, beta]

    # Append any extra activation arguments (e.g. leaky-ReLU slope) after alpha/beta
    if hasattr(self, "_activation_args"):
        if isinstance(self._activation_args, list):
            epilogue_args += self._activation_args
        else:
            epilogue_args.append(self._activation_args)

    # For parallel split-k the activation is applied by the separate reduction pass,
    # so the main kernel gets an identity epilogue.
    if split_k[0] == "parallel" and split_k[1] > 1:
        epilogue_functor = self._create_epilogue_functor_activation(epilogue.identity)
    else:
        epilogue_functor = self.epilogue_functor

    # The alignment is determined by the iterator function (I believe)
    self.compile(tile_description=self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
                 alignment_C=alignment_c, iterator_algorithm=iterator_algorithm, stride_support=stride_support,
                 swizzling_functor=swizzling_functor, epilogue_functor=epilogue_functor, print_module=print_module)

    # Create reduction operation for parallel split-k
    if split_k[0] == "parallel" and split_k[1] > 1:
        epilogue_functor_reduction = self._reset_epilogue_functor_alignment(alignment_c, self.epilogue_functor)
        self.reduction_operation = ReductionOperation(
            shape=cutlass_bindings.MatrixCoord(4, 32 * alignment_c), C=self.operation.C,
            element_accumulator=datatypes.binding_type(self._element_accumulator),
            element_compute=datatypes.binding_type(self._element_accumulator),
            epilogue_functor=epilogue_functor_reduction,
            count=alignment_c
        )
        if print_module:
            print(self.reduction_operation.rt_module.emit())
        compiler.add_module([self.reduction_operation,])

    arguments = Conv2dArguments(
        operation=self.operation, problem_size=problem_size,
        A=A, B=B, C=C, D=D,
        output_op=self.operation.epilogue_type(*epilogue_args),
        split_k_mode=datatypes.getattr_enum(cutlass_bindings.conv.SplitKMode, split_k[0]),
        split_k_slices=split_k[1]
    )

    self.operation.run(arguments)

    # Parallel split-k: reduce the partial results (written to the workspace) into D
    if split_k[0] == "parallel" and split_k[1] > 1:
        implicit_gemm_size = cutlass_bindings.conv.implicit_gemm_problem_size(
            self.conv_kind, arguments.problem_size
        )
        reduction_arguments = ReductionArguments(
            self.reduction_operation,
            problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()],
            partitions=split_k[1],
            workspace=arguments.ptr_D,
            destination=D,
            source=C,
            output_op=self.reduction_operation.epilogue_type(*epilogue_args)
        )
        self.reduction_operation.run(reduction_arguments)

    if sync:
        # The reduction pass is the last kernel launched, so wait on it when present
        if split_k[0] == "parallel" and split_k[1] > 1:
            reduction_arguments.sync()
        else:
            arguments.sync()

    return arguments
|
||||
|
||||
#
|
||||
# Helper functions
|
||||
#
|
||||
@staticmethod
def output_size(input_size, weight_size, padding, stride, dilation):
    """
    Computes the NPQK shape of the Conv2d output for the given NHWC input,
    KRSC filter, and convolution parameters.

    :param input_size: (N, H, W, C) of the activation tensor
    :param weight_size: (K, R, S, C) of the filter tensor
    :param padding: (pad_h, pad_w)
    :param stride: (stride_h, stride_w)
    :param dilation: (dilation_h, dilation_w)

    :return: (N, P, Q, K) of the output tensor
    :rtype: tuple
    """
    pad_h, pad_w = padding
    stride_h, stride_w = stride
    dilation_h, dilation_w = dilation

    # Let the bindings derive P/Q from the geometry rather than recomputing by hand
    ps = cutlass_bindings.conv.Conv2dProblemSize(
        cutlass_bindings.Tensor4DCoord(*input_size),
        cutlass_bindings.Tensor4DCoord(*weight_size),
        cutlass_bindings.Tensor4DCoord(pad_h, pad_h, pad_w, pad_w),
        cutlass_bindings.MatrixCoord(stride_h, stride_w),
        cutlass_bindings.MatrixCoord(dilation_h, dilation_w),
        cutlass_bindings.conv.Mode.cross_correlation,
        1, 1
    )
    return (ps.N, ps.P, ps.Q, ps.K)
|
||||
|
||||
|
||||
#
|
||||
# Easy to use interfaces for fprop, wgrad, and dgrad
|
||||
#
|
||||
|
||||
class Conv2dFprop(Conv2d):
    """
    Convenience interface for a Conv2d forward-propagation (fprop) plan.
    Maps the fprop-specific operand names onto the generic Conv2d operands
    as A=input, B=weight, D=output.
    """

    def __init__(
        self,
        input=None, weight=None, C=None, output=None, alpha=1, beta=0,
        element=None,
        element_input=None, element_weight=None, element_C=None, element_output=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None):
        # Rename fprop operands to the generic A/B/D roles expected by Conv2d
        A, B, D = input, weight, output
        element_A, element_B, element_D = element_input, element_weight, element_output
        super().__init__(
            "fprop", A, B, C, D, alpha, beta, element,
            element_A, element_B, element_C, element_D,
            element_accumulator, cc, kernel_cc)

    def run(
        self, input=None, weight=None, C=None, output=None, alpha=None, beta=None,
        stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
        sync: bool = True, print_module: bool = False) -> Conv2dArguments:
        """
        Runs the fprop kernel; see ``Conv2d.run`` for parameter semantics.
        """
        A, B, D = input, weight, output
        return super().run(
            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module)
|
||||
|
||||
|
||||
class Conv2dDgrad(Conv2d):
    """
    Convenience interface for a Conv2d data-gradient (dgrad) plan.
    Maps the dgrad-specific operand names onto the generic Conv2d operands
    as A=grad_output, B=weight, D=grad_input.
    """

    def __init__(
        self,
        grad_output=None, weight=None, C=None, grad_input=None, alpha=1, beta=0,
        element=None,
        element_grad_output=None, element_weight=None, element_C=None, element_grad_input=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None):
        # Rename dgrad operands to the generic A/B/D roles expected by Conv2d
        A, B, D = grad_output, weight, grad_input
        element_A, element_B, element_D = element_grad_output, element_weight, element_grad_input
        super().__init__(
            "dgrad", A, B, C, D, alpha, beta, element,
            element_A, element_B, element_C, element_D,
            element_accumulator, cc, kernel_cc)

    def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None,
            stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
            sync: bool = True, print_module: bool = False) -> Conv2dArguments:
        """
        Runs the dgrad kernel; see ``Conv2d.run`` for parameter semantics.
        """
        A, B, D = grad_output, weight, grad_input
        return super().run(
            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module)
|
||||
|
||||
|
||||
class Conv2dWgrad(Conv2d):
    """
    Convenience interface for a Conv2d weight-gradient (wgrad) plan.
    Maps the wgrad-specific operand names onto the generic Conv2d operands
    as A=grad_output, B=input, D=grad_weight.
    """

    def __init__(
        self,
        grad_output=None, input=None, C=None, grad_weight=None, alpha=1, beta=0,
        element=None,
        element_grad_output=None, element_input=None, element_C=None, element_grad_weight=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None):
        # Rename wgrad operands to the generic A/B/D roles expected by Conv2d
        A, B, D = grad_output, input, grad_weight
        element_A, element_B, element_D = element_grad_output, element_input, element_grad_weight
        super().__init__(
            "wgrad", A, B, C, D, alpha, beta, element,
            element_A, element_B, element_C, element_D,
            element_accumulator, cc, kernel_cc)

    def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None,
            stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
            sync: bool = True, print_module: bool = False) -> Conv2dArguments:
        """
        Runs the wgrad kernel; see ``Conv2d.run`` for parameter semantics.
        """
        A, B, D = grad_output, input, grad_weight
        return super().run(
            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module)
|
||||
@@ -287,108 +287,6 @@ class Gemm(OperationBase):
|
||||
if reset_epilogue:
|
||||
self._reset_epilogue_functor_activation(epilogue.identity)
|
||||
|
||||
def _reset_epilogue_functor_activation(self, activation):
    """
    Rebuilds ``self.epilogue_functor`` as a linear-combination epilogue fused with
    ``activation``.

    When no kernel compute capability was explicitly requested by the user, this may
    also switch the plan between SM90 (CUTLASS 3.x) and SM80-tagged (CUTLASS 2.x)
    kernels, because the 3.x kernels currently support only the identity activation.

    :param activation: activation function to fuse into the epilogue
    """
    # Elements per access: reuse the existing epilogue's vector length if one is set;
    # otherwise derive from the opcode class (SIMT kernels use scalar access).
    if self.epilogue_functor is None:
        if self.op_class == cutlass.OpcodeClass.Simt:
            elements_per_access = 1
        else:
            elements_per_access = 128 // cutlass.DataTypeSize[self._element_c]
    else:
        elements_per_access = self.epilogue_functor.epilogue_vector_length

    if not self.specified_kernel_cc:
        if self.current_cc == 90 and activation != epilogue.identity:
            # CUTLASS 3.0 kernels currently only support identity activation. If one requests a non-identity activation,
            # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
            cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
            self._reset_options(80)
            self._reset_operations(reset_epilogue=False)
        elif (self.cc == 90 and self.current_cc != 90 and activation == epilogue.identity):
            # SM80 fallback kernels are currently used. Since an identity activation is requested,
            # we can switch back to using SM90 kernels.
            self._reset_options(90)
            self._reset_operations(reset_epilogue=False)
    else:
        # The user pinned kernel_cc explicitly, so no fallback is allowed
        if self.current_cc == 90 and activation != epilogue.identity:
            raise Exception("Epilogues with elementwise fusion are not currently supported "
                            "in the Python interface for 3.x kernels. To use 2.x kernels "
                            "with fused elementwise epilogues, do not set the `kernel_cc` "
                            "parameter when constructing the Gemm object.")

    self.epilogue_functor = epilogue.get_activation_epilogue(
        activation,
        datatypes.binding_type(self._element_c),
        elements_per_access,
        datatypes.binding_type(self._element_accumulator),
        datatypes.binding_type(self._element_accumulator),
    )
|
||||
|
||||
def _reset_epilogue_functor_alignment(self, alignment):
    """
    Rebuilds ``self.epilogue_functor`` with the given output alignment, preserving the
    currently-selected activation function (identity if none has been set yet).

    :param alignment: alignment (elements per access) for the epilogue
    :type alignment: int
    """
    # Recover the activation type from the existing functor when possible
    if self.epilogue_functor is None or not hasattr(self.epilogue_functor, 'activation_functor'):
        activation = epilogue.identity
    else:
        activation = type(self.epilogue_functor.activation_functor)

    self.epilogue_functor = epilogue.get_activation_epilogue(
        activation,
        datatypes.binding_type(self._element_c),
        alignment,
        datatypes.binding_type(self._element_accumulator),
        datatypes.binding_type(self._element_accumulator),
    )
|
||||
|
||||
@property
def activation(self):
    """
    Returns the type of the current activation function used

    :return: class of the currently-selected activation function
    """
    return type(self.epilogue_functor.activation_functor)

@activation.setter
def activation(self, act):
    """
    Sets the type of the activation function to use
    """
    # Delegates to the epilogue-reset path, which may also switch kernel CC
    self._reset_epilogue_functor_activation(act)
|
||||
|
||||
@property
def opclass(self) -> cutlass.OpcodeClass:
    """
    Returns the opcode class currently in use by the GEMM

    :return: opcode class currently in use
    :rtype: cutlass.OpcodeClass
    """
    return self.op_class

@opclass.setter
def opclass(self, oc: cutlass.OpcodeClass):
    """
    Sets the opcode class to use in the GEMM. If the opcode class is not supported under
    the given compute capability and element/layout combinations of the GEMM, an exception is raised.
    """
    if oc in self.possible_op_classes:
        self.op_class = oc
    else:
        raise Exception(
            f'Unsupported operation class {oc} for CC {self.cc} and data type combination '
            f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and '
            f'layout combination ({self._layout_a}, {self._layout_b}).')

    # Changing the op class changes the elements per access in the epilogue. Reset this.
    if self.op_class == cutlass.OpcodeClass.Simt:
        elements_per_access = 1
    else:
        elements_per_access = 128 // cutlass.DataTypeSize[self._element_c]

    if self.epilogue_functor is not None:
        self._reset_epilogue_functor_alignment(elements_per_access)

    # Changing the op class also changes the possible operations available. Reset these.
    self.possible_operations = self.options.operations(
        self.op_class, self._element_a, self._element_b,
        self._element_accumulator, self._layout_a, self._layout_b)
|
||||
|
||||
@property
|
||||
def swizzling_functor(self):
|
||||
"""
|
||||
@@ -430,7 +328,7 @@ class Gemm(OperationBase):
|
||||
"""
|
||||
# Check stage count based on the CC to which we are compiling (self.cc), rather
|
||||
# than the CC from which we find kernels (self.current_cc)
|
||||
valid, msg = check.valid_stage_count(self.cc, td)
|
||||
valid, msg = check.valid_stage_count(self.cc, td, self._element_c, self._element_d)
|
||||
if not valid:
|
||||
return (valid, msg)
|
||||
|
||||
@@ -438,7 +336,7 @@ class Gemm(OperationBase):
|
||||
if not valid:
|
||||
return (valid, msg)
|
||||
|
||||
valid, msg = check.valid_kernel_schedule(self.current_cc, td.kernel_schedule)
|
||||
valid, msg = check.valid_schedule(self.current_cc, td.kernel_schedule, td.epilogue_schedule, td.tile_scheduler)
|
||||
return valid, msg
|
||||
|
||||
def tile_descriptions(self) -> list:
|
||||
@@ -476,7 +374,7 @@ class Gemm(OperationBase):
|
||||
alignment_B = check.alignment_or_default(alignment_B, alignment_pref_B)
|
||||
alignment_C = check.alignment_or_default(alignment_C, alignment_pref_C)
|
||||
|
||||
self._reset_epilogue_functor_alignment(alignment_C)
|
||||
self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
|
||||
|
||||
tensor_A = TensorDescription(
|
||||
datatypes.binding_type(self._element_a),
|
||||
@@ -562,68 +460,6 @@ class Gemm(OperationBase):
|
||||
f'does not match the expected type and '
|
||||
f'layout of ({ref_type}, {ref_layout}).')
|
||||
|
||||
def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name):
|
||||
"""
|
||||
Verifies the following properties:
|
||||
1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``)
|
||||
2) If ``tensor`` is not ``None``, its datatype and layout must match matches the current versions
|
||||
set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``)
|
||||
|
||||
If either of these properties does not hold, an exception is raised. If these properties hold and
|
||||
``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned.
|
||||
|
||||
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
|
||||
:type tensor: numpy/cupy/torch array/tensor object
|
||||
:param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
|
||||
:type ref_tensor: numpy/cupy/torch array/tensor object
|
||||
:param ref_dtype: data type for the tensor that this object was initialized to
|
||||
:param ref_layout: layout for the tensor that this object was initialized to
|
||||
:param name: identifier of the tensor to verify. Used in raising exceptions
|
||||
:type name: str
|
||||
|
||||
:return: valid tensor object to use
|
||||
:rtype: numpy/cupy/torch array/tensor object
|
||||
"""
|
||||
if tensor is None:
|
||||
if ref_tensor is None:
|
||||
raise Exception(f"Tensor {name} must be set.")
|
||||
return ref_tensor
|
||||
|
||||
self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
|
||||
return tensor
|
||||
|
||||
def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name):
|
||||
"""
|
||||
Verifies the following properties:
|
||||
1) Either ``scalar`` or ``ref_scakar`` must be set (i.e., not ``None``)
|
||||
2) If ``scalar`` is not ``None``, its datatype must match matches the current version
|
||||
set by the plan (i.e., those in ``ref_dtype``)
|
||||
|
||||
If either of these properties does not hold, an exception is raised. If these properties hold and
|
||||
``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned.
|
||||
|
||||
:param scalar: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
|
||||
:type scalar: numpy/cupy/torch scalar
|
||||
:param ref_scalar: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
|
||||
:type ref_scalar: numpy/cupy/torch scalar
|
||||
:param ref_dtype: data type for the scalar that this object was initialized to
|
||||
:param name: identifier of the scalar to verify. Used in raising exceptions
|
||||
:type name: str
|
||||
|
||||
:return: valid scalar to use
|
||||
:rtype: numpy/cupy/torch scalar
|
||||
"""
|
||||
if scalar is None:
|
||||
if ref_scalar is None:
|
||||
raise Exception(f"Scalar {name} must be set.")
|
||||
return ref_scalar
|
||||
dtype = datatypes.library_type(scalar.dtype)
|
||||
if dtype != ref_dtype:
|
||||
raise Exception(
|
||||
f"Tensor {name} with type {dtype} does not match expected type {ref_dtype}."
|
||||
)
|
||||
return scalar
|
||||
|
||||
def run(self, A=None, B=None, C=None, D=None,
|
||||
alpha=None, beta=None, batch_count: int = 1,
|
||||
sync: bool = True, print_module: bool = False) -> GemmArguments:
|
||||
|
||||
@@ -168,7 +168,7 @@ class GroupedGemm(Gemm):
|
||||
alignment_B = check.alignment_or_default(alignment_B, alignment_preference)
|
||||
alignment_C = check.alignment_or_default(alignment_C, alignment_preference)
|
||||
|
||||
self._reset_epilogue_functor_alignment(alignment_C)
|
||||
self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
|
||||
|
||||
tensor_A = TensorDescription(
|
||||
datatypes.binding_type(self._element_a),
|
||||
|
||||
@@ -36,11 +36,13 @@ Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv
|
||||
|
||||
from bisect import bisect_left
|
||||
|
||||
from cutlass import option_registry
|
||||
import cutlass
|
||||
from cutlass import option_registry, epilogue
|
||||
from cutlass.backend.utils.device import device_cc
|
||||
from cutlass.epilogue import get_activations
|
||||
from cutlass.library_defaults import _generator_ccs
|
||||
from cutlass.swizzle import get_swizzling_functors
|
||||
from cutlass.utils import datatypes
|
||||
|
||||
|
||||
class OperationBase:
|
||||
@@ -48,22 +50,26 @@ class OperationBase:
|
||||
Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
|
||||
"""
|
||||
|
||||
def __init__(self, cc: int = None, kernel_cc: int = None):
|
||||
def __init__(self, cc: int = None, kernel_cc: int = None, operation_kind = cutlass.OperationKind.Gemm):
|
||||
"""
|
||||
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
|
||||
:type cc: int
|
||||
:param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
|
||||
:type kernel_cc: int
|
||||
"""
|
||||
self.operation_kind = operation_kind
|
||||
self.cc = cc if cc is not None else device_cc()
|
||||
self.specified_kernel_cc = kernel_cc is not None
|
||||
self.current_cc = kernel_cc if kernel_cc is not None else self._find_closest_cc(self.cc)
|
||||
self.tile_description = None
|
||||
|
||||
self.options = option_registry.options_for_cc(self.current_cc)
|
||||
self.options = option_registry.options_for_cc(self.current_cc, operation_kind)
|
||||
|
||||
if self.options is None:
|
||||
raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}")
|
||||
|
||||
# Default activation function: identity
|
||||
self._activation = epilogue.identity
|
||||
|
||||
def _find_closest_cc(self, cc: int) -> int:
|
||||
"""
|
||||
@@ -113,4 +119,210 @@ class OperationBase:
|
||||
if cc not in _generator_ccs:
|
||||
raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.')
|
||||
self.current_cc = cc
|
||||
self.options = option_registry.options_for_cc(self.current_cc)
|
||||
self.options = option_registry.options_for_cc(self.current_cc, self.operation_kind)
|
||||
|
||||
def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name):
|
||||
"""
|
||||
Verifies the following properties:
|
||||
1) Either ``scalar`` or ``ref_scakar`` must be set (i.e., not ``None``)
|
||||
2) If ``scalar`` is not ``None``, its datatype must match matches the current version
|
||||
set by the plan (i.e., those in ``ref_dtype``)
|
||||
|
||||
If either of these properties does not hold, an exception is raised. If these properties hold and
|
||||
``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned.
|
||||
|
||||
:param scalar: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
|
||||
:type scalar: numpy/cupy/torch scalar
|
||||
:param ref_scalar: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
|
||||
:type ref_scalar: numpy/cupy/torch scalar
|
||||
:param ref_dtype: data type for the scalar that this object was initialized to
|
||||
:param name: identifier of the scalar to verify. Used in raising exceptions
|
||||
:type name: str
|
||||
|
||||
:return: valid scalar to use
|
||||
:rtype: numpy/cupy/torch scalar
|
||||
"""
|
||||
if scalar is None:
|
||||
if ref_scalar is None:
|
||||
raise Exception(f"Scalar {name} must be set.")
|
||||
return ref_scalar
|
||||
if hasattr(scalar, "dtype"):
|
||||
dtype = datatypes.library_type(scalar.dtype)
|
||||
if dtype != ref_dtype:
|
||||
raise Exception(
|
||||
f"Tensor {name} with type {dtype} does not match expected type {ref_dtype}."
|
||||
)
|
||||
return scalar
|
||||
|
||||
def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name):
    """
    Verifies the following properties:
        1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``)
        2) If ``tensor`` is not ``None``, its datatype and layout must match the current versions
           set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``)

    If either of these properties does not hold, an exception is raised. If these properties hold and
    ``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned.

    :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
    :type tensor: numpy/cupy/torch array/tensor object
    :param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
    :type ref_tensor: numpy/cupy/torch array/tensor object
    :param ref_dtype: data type for the tensor that this object was initialized to
    :param ref_layout: layout for the tensor that this object was initialized to
    :param name: identifier of the tensor to verify. Used in raising exceptions
    :type name: str

    :return: valid tensor object to use
    :rtype: numpy/cupy/torch array/tensor object
    """
    if tensor is None:
        if ref_tensor is None:
            raise Exception(f"Tensor {name} must be set.")
        return ref_tensor

    # A call-site tensor overrides the constructed one, but must match the plan
    self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
    return tensor
|
||||
|
||||
#
|
||||
# Opcode Related
|
||||
#
|
||||
|
||||
@property
def opclass(self) -> cutlass.OpcodeClass:
    """
    Returns the opcode class currently in use by the GEMM

    :return: opcode class currently in use
    :rtype: cutlass.OpcodeClass
    """
    return self.op_class

@opclass.setter
def opclass(self, oc: cutlass.OpcodeClass):
    """
    Sets the opcode class to use. If the opcode class is not supported under the
    current compute capability and element/layout combination, an exception is raised.
    Accepts either a ``cutlass.OpcodeClass`` value or its string name.
    """
    # Allow string names (e.g. "simt") in addition to enum values
    if isinstance(oc, str):
        oc = datatypes.getattr_enum(cutlass.OpcodeClass, oc)
    if oc in self.possible_op_classes:
        self.op_class = oc
    else:
        raise Exception(
            f'Unsupported operation class {oc} for CC {self.cc} and data type combination '
            f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and '
            f'layout combination ({self._layout_a}, {self._layout_b}).')

    # Changing the op class changes the elements per access in the epilogue. Reset this.
    if self.op_class == cutlass.OpcodeClass.Simt:
        elements_per_access = 1
    else:
        elements_per_access = 128 // cutlass.DataTypeSize[self._element_c]

    if self.epilogue_functor is not None:
        self.epilogue_functor = self._reset_epilogue_functor_alignment(elements_per_access, self.epilogue_functor)

    # Changing the op class also changes the possible operations available. Reset these.
    self.possible_operations = self.options.operations(
        self.op_class, self._element_a, self._element_b,
        self._element_accumulator, self._layout_a, self._layout_b)
|
||||
|
||||
#
|
||||
# Epilogue
|
||||
#
|
||||
|
||||
def _create_epilogue_functor_activation(self, activation):
|
||||
"""
|
||||
Returns the epilogue functor with given activation function
|
||||
"""
|
||||
if self.epilogue_functor is None:
|
||||
if self.op_class == cutlass.OpcodeClass.Simt:
|
||||
elements_per_access = 1
|
||||
else:
|
||||
elements_per_access = 128 // cutlass.DataTypeSize[self._element_c]
|
||||
else:
|
||||
elements_per_access = self.epilogue_functor.epilogue_vector_length
|
||||
|
||||
if not self.specified_kernel_cc:
|
||||
if self.current_cc == 90 and activation != epilogue.identity:
|
||||
# CUTLASS 3.0 kernels currently only support identity activation. If one requests a non-identity activation,
|
||||
# revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
|
||||
cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
|
||||
self._reset_options(80)
|
||||
self._reset_operations(reset_epilogue=False)
|
||||
elif (self.cc == 90 and self.current_cc != 90 and activation == epilogue.identity):
|
||||
# SM80 fallback kernels are currently used. Since an identity activation is requested,
|
||||
# we can switch back to using SM90 kernels.
|
||||
self._reset_options(90)
|
||||
self._reset_operations(reset_epilogue=False)
|
||||
else:
|
||||
if self.current_cc == 90 and activation != epilogue.identity:
|
||||
raise Exception("Epilogues with elementwise fusion are not currently supported "
|
||||
"in the Python interface for 3.x kernels. To use 2.x kernels "
|
||||
"with fused elementwise epilogues, do not set the `kernel_cc` "
|
||||
"parameter when constructing the Gemm object.")
|
||||
|
||||
return epilogue.get_activation_epilogue(
|
||||
activation,
|
||||
datatypes.binding_type(self._element_c),
|
||||
elements_per_access,
|
||||
datatypes.binding_type(self._element_accumulator),
|
||||
datatypes.binding_type(self._element_accumulator),
|
||||
)
|
||||
|
||||
def _reset_epilogue_functor_activation(self, activation):
|
||||
"""
|
||||
Set the epilogue functor based on the provided activation function
|
||||
"""
|
||||
self.epilogue_functor = self._create_epilogue_functor_activation(activation)
|
||||
|
||||
def _reset_epilogue_functor_alignment(self, alignment, epilogue_functor):
|
||||
"""
|
||||
Reset the alignment of the current epilogue functor based on alignment C
|
||||
"""
|
||||
if epilogue_functor is None or not hasattr(epilogue_functor, 'activation_functor'):
|
||||
# Identity epilogue does not have 'activation_functor'
|
||||
activation = epilogue.identity
|
||||
else:
|
||||
activation = type(epilogue_functor.activation_functor)
|
||||
|
||||
epilogue_functor = epilogue.get_activation_epilogue(
|
||||
activation,
|
||||
datatypes.binding_type(self._element_c),
|
||||
alignment,
|
||||
datatypes.binding_type(self._element_accumulator),
|
||||
datatypes.binding_type(self._element_accumulator),
|
||||
)
|
||||
return epilogue_functor
|
||||
|
||||
@property
|
||||
def activation(self):
|
||||
"""
|
||||
Returns the type of the current activation function used
|
||||
"""
|
||||
if hasattr(self.epilogue_functor, "activation_functor"):
|
||||
return type(self.epilogue_functor.activation_functor)
|
||||
else:
|
||||
return epilogue.identity
|
||||
|
||||
@activation.setter
|
||||
def activation(self, act):
|
||||
"""
|
||||
Sets the type of the activation function to use
|
||||
Activation can come with a set of arguments
|
||||
|
||||
:param act: type of activation function to use
|
||||
:type act: str or tuple. e.g. "relu", ("leaky_relu", 0.01)
|
||||
|
||||
"""
|
||||
if isinstance(act, tuple):
|
||||
if isinstance(act[0], str):
|
||||
act_fn = getattr(cutlass.backend.epilogue, act[0])
|
||||
else:
|
||||
act_fn = act[0]
|
||||
self._reset_epilogue_functor_activation(act_fn)
|
||||
self._activation_args = act[1]
|
||||
self._activation = act[0]
|
||||
else:
|
||||
if isinstance(act, str):
|
||||
act = getattr(cutlass.backend.epilogue, act)
|
||||
self._reset_epilogue_functor_activation(act)
|
||||
self._activation = act
|
||||
|
||||
|
||||
@@ -32,9 +32,10 @@
|
||||
|
||||
from cutlass.utils.check import (
|
||||
alignment_or_default,
|
||||
update_alignment,
|
||||
calculate_smem_usage,
|
||||
calculate_smem_usage_per_stage,
|
||||
valid_cluster_shape,
|
||||
valid_kernel_schedule,
|
||||
valid_schedule,
|
||||
valid_stage_count,
|
||||
)
|
||||
|
||||
@@ -39,29 +39,35 @@ import ctypes
|
||||
import cutlass_bindings
|
||||
import cutlass
|
||||
from cutlass.backend.library import DataTypeSize, TileDescription
|
||||
from cutlass.utils.datatypes import binding_type
|
||||
|
||||
|
||||
def calculate_smem_usage_per_stage(tile_description, operation_kind):
|
||||
def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: cutlass.OperationKind) -> int:
|
||||
"""
|
||||
Returns the amount of shared memory in bytes consumed in a single stage of a kernel.
|
||||
|
||||
:param td: tile description to compute shared memory of
|
||||
:type td: TileDescription
|
||||
:param operation_kind: identifier for the type of operation being performed
|
||||
:type operation_kind: cutlass.OperationKind
|
||||
|
||||
:return: number of bytes of shared memory consumed by a single stage
|
||||
:rtype: int
|
||||
"""
|
||||
m, n, k = tile_description.threadblock_shape
|
||||
m, n, k = td.threadblock_shape
|
||||
|
||||
if operation_kind == cutlass.OperationKind.Gemm:
|
||||
stage_barrier_bytes = 32
|
||||
return (
|
||||
(DataTypeSize[tile_description.math_instruction.element_a] * m * k // 8)
|
||||
+ (DataTypeSize[tile_description.math_instruction.element_b] * k * n // 8)
|
||||
(DataTypeSize[td.math_instruction.element_a] * m * k // 8)
|
||||
+ (DataTypeSize[td.math_instruction.element_b] * k * n // 8)
|
||||
+ stage_barrier_bytes
|
||||
)
|
||||
else:
|
||||
raise Exception(f"No available shared memory calculation for operation kind {operation.operation_kind}")
|
||||
|
||||
|
||||
def calculate_smem_usage(operation):
|
||||
def calculate_smem_usage(operation) -> int:
|
||||
"""
|
||||
Returns the amount of shared memory in bytes consumed by a kernel.
|
||||
|
||||
@@ -72,7 +78,11 @@ def calculate_smem_usage(operation):
|
||||
return _per_stage * operation.tile_description.stages
|
||||
|
||||
|
||||
def valid_stage_count(cc: int, td: TileDescription) -> tuple:
|
||||
def valid_stage_count(
|
||||
cc: int,
|
||||
td: TileDescription,
|
||||
element_C: cutlass.DataType = None,
|
||||
element_D: cutlass.DataType = None) -> tuple:
|
||||
"""
|
||||
Checks whether a device with `cc` supports the number of stages within `tile_description`, both
|
||||
based on raw limits on the number of stages and based on shared memory capacity
|
||||
@@ -81,15 +91,26 @@ def valid_stage_count(cc: int, td: TileDescription) -> tuple:
|
||||
:type cc: int
|
||||
:param td: tile description to check
|
||||
:type td: TileDescription
|
||||
:param element_C: data type of operand C
|
||||
:type element_C: cutlass.DataType
|
||||
:param element_D: data type of operand D
|
||||
:type element_D: cutlass.DataType
|
||||
|
||||
:return: tuple with the first element indicating whether the provided tile description is
|
||||
valid for the provided device and the second element being an error message
|
||||
:rtype: tuple
|
||||
"""
|
||||
if cc == 90 and (td.stages is None or td.stages == 0):
|
||||
# Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically
|
||||
# determines the stage count to use. Thus, all settings are valid in these scenarios.
|
||||
return (True, "")
|
||||
if cc == 90:
|
||||
if (td.stages is None or td.stages == 0):
|
||||
# Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically
|
||||
# determines the stage count to use. Thus, all settings are valid in these scenarios.
|
||||
return (True, "")
|
||||
else:
|
||||
cutlass.logger.warning(
|
||||
"Setting an explicit stage count for SM90 kernels currently may "
|
||||
"result in compilation errors if the combination of tile shape, "
|
||||
"stage count, and shared memory requirement of the epilogue exceeds "
|
||||
"the available shared memory per SM.")
|
||||
|
||||
if td.stages <= 0:
|
||||
return (False, f"Stage counts must be positive integers. Tile description has stage count of {td.stages}.")
|
||||
@@ -98,14 +119,20 @@ def valid_stage_count(cc: int, td: TileDescription) -> tuple:
|
||||
return (False, f"Tile description has stage count of {td.stages}, "
|
||||
f"but only 2 stages are supported on SM{cc}.")
|
||||
|
||||
# The calculation below does not consider shared memory used by the epilogue and, thus,
|
||||
# only catches cases in which the mainloop exceeds the device's shared memory capacity.
|
||||
# This is not a concern for CUTLASS 2.x kernels, for which the shared memory of the
|
||||
# mainloop and epilogue is shared.
|
||||
smem_per_stage = calculate_smem_usage_per_stage(td, cutlass.OperationKind.Gemm)
|
||||
smem_usage_mainloop = (smem_per_stage * td.stages)
|
||||
smem_arch = cutlass.SharedMemPerCC[cc] << 10
|
||||
if (smem_per_stage * td.stages) > smem_arch:
|
||||
if smem_usage_mainloop > smem_arch:
|
||||
return ( False,
|
||||
"Configuration uses too much shared memory. Consider reducing stage count or tile shape.\n"
|
||||
f"Details: configuration uses {smem_per_stage} bytes of shared memory per stage, and "
|
||||
f"{td.stages} stages for a total of {smem_per_stage * td.stages} bytes.\n"
|
||||
f"The maxmium amoung of shared memory that can be used per block on CC {cc} is {smem_arch}.")
|
||||
f"Details:\n"
|
||||
f"Mainloop uses {smem_per_stage} bytes of shared memory per stage, and "
|
||||
f"{td.stages} stages for a total of {smem_usage_mainloop} bytes.\n"
|
||||
f"The maxmium amount of shared memory that can be used per block on CC {cc} is {smem_arch}.")
|
||||
|
||||
return (True, "")
|
||||
|
||||
@@ -153,21 +180,40 @@ def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple:
|
||||
return (True, "")
|
||||
|
||||
|
||||
def valid_kernel_schedule(cc: int, kernel_schedule: cutlass.KernelScheduleType) -> tuple:
|
||||
def valid_schedule(
|
||||
cc: int,
|
||||
kernel_schedule: cutlass.KernelScheduleType,
|
||||
epilogue_schedule: cutlass.EpilogueScheduleType,
|
||||
tile_scheduler: cutlass.TileSchedulerType) -> tuple:
|
||||
"""
|
||||
Checks whether a device with ``cc`` supports ``kernel_schedule``.
|
||||
Checks that the kernel and epilogue schedules passed in are a valid combination for
|
||||
a device of compute capability ``cc``.
|
||||
|
||||
:param cc: compute capability of device in question
|
||||
:type cc: int
|
||||
:param kernel_schedule: kernel schedule type
|
||||
:type KernelScheduleType: cutlass.KernelScheduleType
|
||||
:type kernel_schedule: cutlass.KernelScheduleType
|
||||
:param epilogue_schedule: epilogue schedule type
|
||||
:type epilogue_schedule: cutlass.EpilogueScheduleType
|
||||
:param tile_scheduler: tile scheduler type
|
||||
:type tile_scheduler: cutlass.TileSchedulerType
|
||||
|
||||
:return: tuple with the first element indicating whether the provided kernel schedule is
|
||||
:return: tuple with the first element indicating whether the provided schedules are
|
||||
valid for the provided device and the second element being an error message
|
||||
:rtype: tuple
|
||||
"""
|
||||
if kernel_schedule != cutlass.KernelScheduleType.ScheduleAuto and cc < 90:
|
||||
return (False, "Non-default kernel schedules are only supported on SM90 and beyond")
|
||||
kernel_auto = (kernel_schedule == cutlass.KernelScheduleType.ScheduleAuto)
|
||||
epilogue_auto = (epilogue_schedule == cutlass.EpilogueScheduleType.ScheduleAuto)
|
||||
tile_scheduler_default = (tile_scheduler == cutlass.TileSchedulerType.Default)
|
||||
if cc < 90 and not (kernel_auto and epilogue_auto and tile_scheduler_default):
|
||||
return (False, "Non-default schedules are only supported on SM90 and beyond")
|
||||
|
||||
if (kernel_auto and not epilogue_auto) or (not kernel_auto and epilogue_auto):
|
||||
return (False, "Kernel and epilogue schedules must either both be auto or neither be auto")
|
||||
|
||||
if not tile_scheduler_default:
|
||||
if (tile_scheduler == cutlass.TileSchedulerType.StreamK) and (kernel_schedule != cutlass.KernelScheduleType.TmaWarpSpecializedCooperative):
|
||||
return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule")
|
||||
return (True, "")
|
||||
|
||||
|
||||
@@ -190,3 +236,26 @@ def alignment_or_default(alignment_provided: int, default_alignment: int) -> int
|
||||
return alignment_provided
|
||||
|
||||
return default_alignment
|
||||
|
||||
|
||||
def update_alignment(alignment_provided:int, default_alignment: int) -> int:
|
||||
"""
|
||||
Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks
|
||||
that `alignment_provided` does not exceed `default_alignment`.
|
||||
|
||||
:param alignment_provided: alignment preference specified. Can be None.
|
||||
:type alignment_provided: int
|
||||
:param default_alignment: alignment to use if `alignment_provided` is None
|
||||
:type default_alignment: int
|
||||
|
||||
:return: alignment to use
|
||||
:rtype: int
|
||||
"""
|
||||
if alignment_provided is not None:
|
||||
if alignment_provided > default_alignment:
|
||||
if alignment_provided % default_alignment == 0:
|
||||
return default_alignment
|
||||
raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.")
|
||||
return alignment_provided
|
||||
|
||||
return default_alignment
|
||||
|
||||
@@ -232,6 +232,8 @@ def library_layout(layout):
|
||||
return cutlass.LayoutType.RowMajor
|
||||
elif layout == cutlass_bindings.ColumnMajor:
|
||||
return cutlass.LayoutType.ColumnMajor
|
||||
elif layout == cutlass_bindings.TensorNHWC:
|
||||
return cutlass.LayoutType.TensorNHWC
|
||||
else:
|
||||
raise Exception(f"No conversion available for layout {layout} to library layout.")
|
||||
|
||||
@@ -251,6 +253,8 @@ def binding_layout(layout):
|
||||
return cutlass_bindings.RowMajor
|
||||
elif layout == cutlass.LayoutType.ColumnMajor:
|
||||
return cutlass_bindings.ColumnMajor
|
||||
elif layout == cutlass.LayoutType.TensorNHWC:
|
||||
return cutlass_bindings.TensorNHWC
|
||||
else:
|
||||
raise Exception(f"No conversion available for layout {layout} to Python-bound CUTLASS layout.")
|
||||
|
||||
@@ -279,6 +283,16 @@ def get_datatype_and_layout(tensor):
|
||||
else:
|
||||
raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
|
||||
|
||||
def get_tensor_shape(tensor):
|
||||
if (numpy_available and isinstance(tensor, np.ndarray)) or (
|
||||
cupy_available and isinstance(tensor, cp.ndarray)
|
||||
):
|
||||
return tensor.shape
|
||||
elif torch_available and isinstance(tensor, torch.Tensor):
|
||||
size = tensor.size()
|
||||
return (size[0], size[2], size[3], size[1])
|
||||
else:
|
||||
raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
|
||||
|
||||
def binding_opclass(opclass: cutlass.OpcodeClass):
|
||||
if opclass == cutlass.OpcodeClass.TensorOp:
|
||||
@@ -299,7 +313,9 @@ def backend_math_operation(math_op: cutlass.MathOperation):
|
||||
|
||||
|
||||
def construct_backend_td(td: cutlass.TileDescription,
|
||||
kernel_schedule: cutlass.KernelScheduleType) -> TileDescription:
|
||||
kernel_schedule: cutlass.KernelScheduleType,
|
||||
epilogue_schedule: cutlass.EpilogueScheduleType,
|
||||
tile_scheduler: cutlass.TileSchedulerType) -> TileDescription:
|
||||
mi = td.math_instruction
|
||||
backend_mi = MathInstruction(
|
||||
mi.instruction_shape,
|
||||
@@ -309,8 +325,9 @@ def construct_backend_td(td: cutlass.TileDescription,
|
||||
binding_opclass(mi.opcode_class),
|
||||
backend_math_operation(mi.math_operation)
|
||||
)
|
||||
cluster_shape = td.cluster_shape if hasattr(td, "cluster_shape") else [1, 1, 1]
|
||||
return TileDescription(td.threadblock_shape, td.stages, td.warp_count,
|
||||
backend_mi, td.cluster_shape, kernel_schedule)
|
||||
backend_mi, cluster_shape, kernel_schedule, epilogue_schedule, tile_scheduler)
|
||||
|
||||
|
||||
def td_from_profiler_op(op) -> TileDescription:
|
||||
@@ -322,8 +339,10 @@ def td_from_profiler_op(op) -> TileDescription:
|
||||
:returns: backend TileDescription
|
||||
:rtype: cutlass.backend.TileDescription
|
||||
"""
|
||||
schedule = op.kernel_schedule if hasattr(op, 'kernel_schedule') else None
|
||||
return construct_backend_td(op.tile_description, schedule)
|
||||
kschedule = op.kernel_schedule if hasattr(op, 'kernel_schedule') else None
|
||||
eschedule = op.epilogue_schedule if hasattr(op, 'epilogue_schedule') else None
|
||||
tschedule = op.tile_scheduler if hasattr(op, 'tile_scheduler') else None
|
||||
return construct_backend_td(op.tile_description, kschedule, eschedule, tschedule)
|
||||
|
||||
|
||||
def td_from_profiler_td(td: cutlass.backend.TileDescription) -> TileDescription:
|
||||
@@ -336,4 +355,16 @@ def td_from_profiler_td(td: cutlass.backend.TileDescription) -> TileDescription:
|
||||
:returns: backend TileDescription
|
||||
:rtype: cutlass.backend.TileDescription
|
||||
"""
|
||||
return construct_backend_td(td, kernel_schedule=None)
|
||||
return construct_backend_td(td, kernel_schedule=None, epilogue_schedule=None, tile_scheduler=None)
|
||||
|
||||
def to_camel_case(snake_str):
|
||||
return "".join(x.capitalize() for x in snake_str.lower().split("_"))
|
||||
|
||||
|
||||
def getattr_enum(obj, attr_name):
|
||||
# The attr_name is under the snake_case
|
||||
camel_attr = to_camel_case(attr_name)
|
||||
if hasattr(obj, camel_attr):
|
||||
return getattr(obj, camel_attr)
|
||||
else:
|
||||
raise Exception(f"Invalid option: {attr_name}")
|
||||
|
||||
@@ -112,6 +112,7 @@ library_dirs = [
|
||||
cuda_install_path + '/lib64',
|
||||
]
|
||||
|
||||
|
||||
ext_modules = [
|
||||
Pybind11Extension('cutlass_bindings',
|
||||
['cutlass/cpp/cutlass_bindings.cpp'],
|
||||
|
||||
Reference in New Issue
Block a user