mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-12 09:15:56 +00:00
CUTLASS 3.2.1 (#1113)
* Updates for 3.2.1 release. * Minor fix in gemm op profiler for raster order. * Add scheduler mapping for raster order in the kernels.
This commit is contained in:
@@ -32,10 +32,10 @@
|
||||
|
||||
from cutlass.utils.check import (
|
||||
alignment_or_default,
|
||||
update_alignment,
|
||||
calculate_smem_usage,
|
||||
calculate_smem_usage_per_stage,
|
||||
valid_cluster_shape,
|
||||
valid_schedule,
|
||||
valid_stage_count,
|
||||
update_alignment,
|
||||
)
|
||||
|
||||
@@ -36,10 +36,9 @@ Utility functions for checking constraints on kernels and calculating kernel att
|
||||
|
||||
import ctypes
|
||||
|
||||
import cutlass_bindings
|
||||
import cutlass
|
||||
from cutlass.backend.library import DataTypeSize, TileDescription
|
||||
from cutlass.utils.datatypes import binding_type
|
||||
from cutlass import DataTypeSize
|
||||
from cutlass.backend.library import TileDescription
|
||||
|
||||
|
||||
def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: cutlass.OperationKind) -> int:
|
||||
@@ -80,6 +79,7 @@ def calculate_smem_usage(operation) -> int:
|
||||
|
||||
def valid_stage_count(
|
||||
cc: int,
|
||||
kernel_cc: int,
|
||||
td: TileDescription,
|
||||
element_C: cutlass.DataType = None,
|
||||
element_D: cutlass.DataType = None) -> tuple:
|
||||
@@ -89,6 +89,8 @@ def valid_stage_count(
|
||||
|
||||
:param cc: compute capability of device in question
|
||||
:type cc: int
|
||||
:param kernel_cc: compute capability that the kernel targets (corresponding to the arch::SMxy tag in CUTLASS)
|
||||
:type kernel_cc: int
|
||||
:param td: tile description to check
|
||||
:type td: TileDescription
|
||||
:param element_C: data type of operand C
|
||||
@@ -100,7 +102,7 @@ def valid_stage_count(
|
||||
valid for the provided device and the second element being an error message
|
||||
:rtype: tuple
|
||||
"""
|
||||
if cc == 90:
|
||||
if kernel_cc == 90:
|
||||
if (td.stages is None or td.stages == 0):
|
||||
# Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically
|
||||
# determines the stage count to use. Thus, all settings are valid in these scenarios.
|
||||
|
||||
@@ -34,14 +34,13 @@
|
||||
Utility functions for converting between frontend datatypes and CUTLASS datatypes
|
||||
"""
|
||||
|
||||
import cutlass_bindings
|
||||
|
||||
import cutlass
|
||||
from cutlass.backend.library import (
|
||||
from cutlass import (
|
||||
DataTypeSize,
|
||||
)
|
||||
from cutlass.backend.library import (
|
||||
MathInstruction,
|
||||
MathOperation,
|
||||
ShortLayoutTypeNames,
|
||||
TileDescription,
|
||||
)
|
||||
|
||||
@@ -123,6 +122,9 @@ try:
|
||||
torch.float32: cutlass.DataType.f32,
|
||||
torch.double: cutlass.DataType.f64,
|
||||
torch.float64: cutlass.DataType.f64,
|
||||
torch.int8: cutlass.DataType.s8,
|
||||
torch.int32: cutlass.DataType.s32,
|
||||
torch.uint8: cutlass.DataType.u8,
|
||||
}
|
||||
|
||||
_library_to_torch_dict = {
|
||||
@@ -133,6 +135,9 @@ try:
|
||||
cutlass.DataType.f32: torch.float32,
|
||||
cutlass.DataType.f64: torch.double,
|
||||
cutlass.DataType.f64: torch.float64,
|
||||
cutlass.DataType.s8: torch.int8,
|
||||
cutlass.DataType.s32: torch.int32,
|
||||
cutlass.DataType.u8: torch.uint8,
|
||||
}
|
||||
except ImportError:
|
||||
torch_available = False
|
||||
@@ -162,51 +167,12 @@ def bfloat16_library_type(inp) -> cutlass.DataType:
|
||||
return cutlass.DataType.bf16
|
||||
|
||||
|
||||
def bfloat16_type(inp) -> bfloat16.bfloat16:
|
||||
def bfloat16_type(inp):
|
||||
if bfloat16_available:
|
||||
if inp == cutlass.DataType.bf16:
|
||||
return bfloat16.bfloat16
|
||||
|
||||
|
||||
# Mapping from library data type to Python-bound CUTLASS data type
|
||||
library_to_binding_dict = {
|
||||
cutlass.DataType.s8: cutlass_bindings.int8,
|
||||
cutlass.DataType.s32: cutlass_bindings.int32,
|
||||
cutlass.DataType.f16: cutlass_bindings.float16,
|
||||
cutlass.DataType.bf16: cutlass_bindings.bfloat16,
|
||||
cutlass.DataType.f32: cutlass_bindings.float32,
|
||||
cutlass.DataType.f64: cutlass_bindings.float64,
|
||||
cutlass.DataType.tf32: cutlass_bindings.tfloat32,
|
||||
}
|
||||
|
||||
# Mapping from Python-bound CUTLASS data type to library data type
|
||||
binding_to_library = {
|
||||
cutlass_bindings.int8: cutlass.DataType.s8,
|
||||
cutlass_bindings.int32: cutlass.DataType.s32,
|
||||
cutlass_bindings.float16: cutlass.DataType.f16,
|
||||
cutlass_bindings.bfloat16: cutlass.DataType.bf16,
|
||||
cutlass_bindings.float32: cutlass.DataType.f32,
|
||||
cutlass_bindings.float64: cutlass.DataType.f64,
|
||||
cutlass_bindings.tfloat32: cutlass.DataType.tf32,
|
||||
}
|
||||
|
||||
|
||||
def binding_library_type(inp):
|
||||
if inp in binding_to_library:
|
||||
return binding_to_library[inp]
|
||||
return None
|
||||
|
||||
|
||||
def has_binding_type(inp: cutlass.DataType):
|
||||
return inp in library_to_binding_dict
|
||||
|
||||
|
||||
def library_to_binding(inp: cutlass.DataType):
|
||||
if not has_binding_type(inp):
|
||||
raise Exception(f"No available conversion from library type {inp} to Python-bound CUTLASS type")
|
||||
return library_to_binding_dict[inp]
|
||||
|
||||
|
||||
def library_type(inp):
|
||||
if inp in cutlass.DataTypeSize.keys():
|
||||
return inp
|
||||
@@ -216,7 +182,6 @@ def library_type(inp):
|
||||
cupy_library_type,
|
||||
numpy_library_type,
|
||||
torch_library_type,
|
||||
binding_library_type,
|
||||
]:
|
||||
out = cvt_fn(inp)
|
||||
if out is not None:
|
||||
@@ -225,42 +190,6 @@ def library_type(inp):
|
||||
raise Exception(f"No available conversion from type {inp} to a library type.")
|
||||
|
||||
|
||||
def library_layout(layout):
|
||||
if layout in cutlass.LayoutTag.keys():
|
||||
return layout
|
||||
|
||||
# Convert Python-bound CUTLASS layout to profiler library layout
|
||||
if layout == cutlass_bindings.RowMajor:
|
||||
return cutlass.LayoutType.RowMajor
|
||||
elif layout == cutlass_bindings.ColumnMajor:
|
||||
return cutlass.LayoutType.ColumnMajor
|
||||
elif layout == cutlass_bindings.TensorNHWC:
|
||||
return cutlass.LayoutType.TensorNHWC
|
||||
else:
|
||||
raise Exception(f"No conversion available for layout {layout} to library layout.")
|
||||
|
||||
|
||||
def binding_type(inp):
|
||||
if inp in DataTypeSize.keys():
|
||||
return inp
|
||||
|
||||
libtype = library_type(inp)
|
||||
return library_to_binding(libtype)
|
||||
|
||||
|
||||
def binding_layout(layout):
|
||||
if layout in ShortLayoutTypeNames.keys():
|
||||
return layout
|
||||
elif layout == cutlass.LayoutType.RowMajor:
|
||||
return cutlass_bindings.RowMajor
|
||||
elif layout == cutlass.LayoutType.ColumnMajor:
|
||||
return cutlass_bindings.ColumnMajor
|
||||
elif layout == cutlass.LayoutType.TensorNHWC:
|
||||
return cutlass_bindings.TensorNHWC
|
||||
else:
|
||||
raise Exception(f"No conversion available for layout {layout} to Python-bound CUTLASS layout.")
|
||||
|
||||
|
||||
def _tensor_from_numpy(np_tensor):
|
||||
dtype = library_type(np_tensor.dtype)
|
||||
if np_tensor.flags.c_contiguous:
|
||||
@@ -282,28 +211,28 @@ def get_datatype_and_layout(tensor):
|
||||
return _tensor_from_numpy(tensor)
|
||||
elif torch_available and isinstance(tensor, torch.Tensor):
|
||||
return _tensor_from_torch(tensor)
|
||||
elif isinstance(tensor, float) or isinstance(tensor, int):
|
||||
return (cutlass.DataType.f32, cutlass.LayoutType.RowMajor)
|
||||
else:
|
||||
raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
|
||||
|
||||
def get_tensor_shape(tensor):
|
||||
def get_tensor_shape(tensor, op="GEMM"):
|
||||
if (numpy_available and isinstance(tensor, np.ndarray)) or (
|
||||
cupy_available and isinstance(tensor, cp.ndarray)
|
||||
):
|
||||
return tensor.shape
|
||||
elif torch_available and isinstance(tensor, torch.Tensor):
|
||||
size = tensor.size()
|
||||
return (size[0], size[2], size[3], size[1])
|
||||
if op == "CONV":
|
||||
# PyTorch Tensors have shape NCHW
|
||||
return (size[0], size[2], size[3], size[1])
|
||||
else:
|
||||
return tuple(tensor.size())
|
||||
elif isinstance(tensor, float) or isinstance(tensor, int):
|
||||
return (1,)
|
||||
else:
|
||||
raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
|
||||
|
||||
def binding_opclass(opclass: cutlass.OpcodeClass):
|
||||
if opclass == cutlass.OpcodeClass.TensorOp:
|
||||
return cutlass_bindings.OpClass.TensorOp
|
||||
elif opclass == cutlass.OpcodeClass.Simt:
|
||||
return cutlass_bindings.OpClass.Simt
|
||||
else:
|
||||
raise Exception(f"Unable to convert opcode class of type {opclass} to Python-bound CUTLASS opcode class.")
|
||||
|
||||
|
||||
_math_operation_value_map = {x.value: x for x in MathOperation}
|
||||
|
||||
@@ -321,10 +250,10 @@ def construct_backend_td(td: cutlass.TileDescription,
|
||||
mi = td.math_instruction
|
||||
backend_mi = MathInstruction(
|
||||
mi.instruction_shape,
|
||||
binding_type(mi.element_a),
|
||||
binding_type(mi.element_b),
|
||||
binding_type(mi.element_accumulator),
|
||||
binding_opclass(mi.opcode_class),
|
||||
mi.element_a,
|
||||
mi.element_b,
|
||||
mi.element_accumulator,
|
||||
mi.opcode_class,
|
||||
backend_math_operation(mi.math_operation)
|
||||
)
|
||||
cluster_shape = td.cluster_shape if hasattr(td, "cluster_shape") else [1, 1, 1]
|
||||
@@ -347,7 +276,7 @@ def td_from_profiler_op(op) -> TileDescription:
|
||||
return construct_backend_td(op.tile_description, kschedule, eschedule, tschedule)
|
||||
|
||||
|
||||
def td_from_profiler_td(td: cutlass.backend.TileDescription) -> TileDescription:
|
||||
def td_from_profiler_td(td: TileDescription) -> TileDescription:
|
||||
"""
|
||||
Converts the profiler's TileDescription into the backend TileDescription
|
||||
|
||||
@@ -359,6 +288,7 @@ def td_from_profiler_td(td: cutlass.backend.TileDescription) -> TileDescription:
|
||||
"""
|
||||
return construct_backend_td(td, kernel_schedule=None, epilogue_schedule=None, tile_scheduler=None)
|
||||
|
||||
|
||||
def to_camel_case(snake_str):
|
||||
return "".join(x.capitalize() for x in snake_str.lower().split("_"))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user