mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-13 09:45:45 +00:00
40
python/cutlass/utils/__init__.py
Normal file
40
python/cutlass/utils/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
from cutlass.utils.check import (
|
||||
alignment_or_default,
|
||||
calculate_smem_usage,
|
||||
calculate_smem_usage_per_stage,
|
||||
valid_cluster_shape,
|
||||
valid_kernel_schedule,
|
||||
valid_stage_count,
|
||||
)
|
||||
192
python/cutlass/utils/check.py
Normal file
192
python/cutlass/utils/check.py
Normal file
@@ -0,0 +1,192 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Utility functions for checking constraints on kernels and calculating kernel attributes
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
|
||||
import cutlass_bindings
|
||||
import cutlass
|
||||
from cutlass.backend.library import DataTypeSize, TileDescription
|
||||
|
||||
|
||||
def calculate_smem_usage_per_stage(tile_description, operation_kind):
    """
    Returns the amount of shared memory in bytes consumed in a single stage of a kernel.

    :param tile_description: tile description whose ``threadblock_shape`` unpacks to ``(m, n, k)``
        and whose ``math_instruction`` provides ``element_a`` and ``element_b``
    :param operation_kind: kind of operation; only ``cutlass.OperationKind.Gemm`` is handled

    :return: number of bytes of shared memory consumed by a single stage
    :rtype: int

    :raises Exception: if ``operation_kind`` is not ``cutlass.OperationKind.Gemm``
    """
    m, n, k = tile_description.threadblock_shape

    if operation_kind == cutlass.OperationKind.Gemm:
        # Fixed per-stage allowance added on top of the operand tiles
        # (presumably for the stage's barrier storage -- confirm against the kernels).
        stage_barrier_bytes = 32
        math_inst = tile_description.math_instruction

        # DataTypeSize values are divided by 8 here, i.e. they are treated as sizes in bits.
        a_tile_bytes = DataTypeSize[math_inst.element_a] * m * k // 8
        b_tile_bytes = DataTypeSize[math_inst.element_b] * k * n // 8
        return a_tile_bytes + b_tile_bytes + stage_barrier_bytes

    # Report the kind that was actually passed in. The previous message referenced an
    # undefined name (`operation`), which raised NameError instead of this exception.
    raise Exception(f"No available shared memory calculation for operation kind {operation_kind}")
|
||||
|
||||
|
||||
def calculate_smem_usage(operation):
    """
    Returns the amount of shared memory in bytes consumed by a kernel.

    The total is the per-stage usage multiplied by the stage count of the operation's
    tile description.

    :param operation: operation exposing ``tile_description`` and ``operation_kind``

    :return: number of bytes of shared memory consumed by the operation
    :rtype: int
    """
    bytes_per_stage = calculate_smem_usage_per_stage(
        operation.tile_description,
        operation.operation_kind,
    )
    stage_count = operation.tile_description.stages
    return bytes_per_stage * stage_count
|
||||
|
||||
|
||||
def valid_stage_count(cc: int, td: TileDescription) -> tuple:
    """
    Checks whether a device with `cc` supports the number of stages within `td`, both
    based on raw limits on the number of stages and based on shared memory capacity

    :param cc: compute capability of device in question
    :type cc: int
    :param td: tile description to check
    :type td: TileDescription

    :return: tuple with the first element indicating whether the provided tile description is
             valid for the provided device and the second element being an error message
    :rtype: tuple
    """
    if cc == 90 and (td.stages is None or td.stages == 0):
        # Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically
        # determines the stage count to use. Thus, all settings are valid in these scenarios.
        return (True, "")

    # A stage count of None is only meaningful on SM90 (handled above). Reject it here
    # explicitly; comparing None with `<=` would otherwise raise TypeError.
    if td.stages is None or td.stages <= 0:
        return (False, f"Stage counts must be positive integers. Tile description has stage count of {td.stages}.")

    if cc < 80 and td.stages != 2:
        return (False, f"Tile description has stage count of {td.stages}, "
                       f"but only 2 stages are supported on SM{cc}.")

    smem_per_stage = calculate_smem_usage_per_stage(td, cutlass.OperationKind.Gemm)
    # The per-CC table entry is shifted left by 10 (multiplied by 1024) to obtain bytes.
    smem_arch = cutlass.SharedMemPerCC[cc] << 10
    smem_total = smem_per_stage * td.stages
    if smem_total > smem_arch:
        return (False,
                "Configuration uses too much shared memory. Consider reducing stage count or tile shape.\n"
                f"Details: configuration uses {smem_per_stage} bytes of shared memory per stage, and "
                f"{td.stages} stages for a total of {smem_total} bytes.\n"
                f"The maximum amount of shared memory that can be used per block on CC {cc} is {smem_arch}.")

    return (True, "")
|
||||
|
||||
|
||||
def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple:
    """
    Checks whether a device with `cc` supports a thread block cluster of shape `cluster_shape`.

    :param cc: compute capability of device in question
    :type cc: int
    :param cluster_shape: dimensions of thread block cluster shape to check. A tuple is
                          accepted and treated the same as the equivalent list.
    :type cluster_shape: list

    :return: tuple with the first element indicating whether the provided cluster shape is
             valid for the provided device and the second element being an error message
    :rtype: tuple
    """
    # Normalize tuples to lists so that (1, 1, 1) compares equal to [1, 1, 1] below.
    # Error messages still show the shape exactly as the caller passed it.
    if isinstance(cluster_shape, tuple):
        shape = list(cluster_shape)
    else:
        shape = cluster_shape

    if cc < 90:
        if shape != [1, 1, 1]:
            return (False,
                    f"Cluster shape for pre-SM90 architectures must be [1, 1, 1]. Received cluster shape of "
                    f"{cluster_shape} for SM{cc}.")
        return (True, "")

    if len(shape) != 3:
        return (False,
                f"Cluster shapes must be rank-3. Received {cluster_shape} (rank {len(shape)})")

    if shape[2] != 1:
        return (False,
                "CUTLASS kernels currently require the third dimension of cluster shape to be 1. "
                f"Received cluster shape of {cluster_shape}.")

    # The CUDA programming guide currently defines a maximum of 8 thread blocks per cluster
    # as being portably supported (https://docs.nvidia.com/cuda/cuda-c-programming-guide/#thread-block-clusters).
    # Current CUTLASS kernels only have non-unit cluster dimensions within the first two dimensions,
    # so we check that the first two dimensions of the cluster shape do not exceed 8 thread blocks in total.
    blocks_in_2d = shape[0] * shape[1]
    if blocks_in_2d > 8:
        return (False,
                f"Thread block clusters with more than 8 thread blocks are currently unsupported on SM{cc}. "
                f"Received cluster shape {cluster_shape}, which has {blocks_in_2d} thread blocks.")
    return (True, "")
|
||||
|
||||
|
||||
def valid_kernel_schedule(cc: int, kernel_schedule: cutlass.KernelScheduleType) -> tuple:
    """
    Checks whether a device with ``cc`` supports ``kernel_schedule``.

    ``cutlass.KernelScheduleType.ScheduleAuto`` is accepted for every ``cc``; any other
    schedule is accepted only when ``cc`` is at least 90.

    :param cc: compute capability of device in question
    :type cc: int
    :param kernel_schedule: kernel schedule type
    :type kernel_schedule: cutlass.KernelScheduleType

    :return: tuple with the first element indicating whether the provided kernel schedule is
             valid for the provided device and the second element being an error message
    :rtype: tuple
    """
    is_default_schedule = kernel_schedule == cutlass.KernelScheduleType.ScheduleAuto
    if is_default_schedule or cc >= 90:
        return (True, "")
    return (False, "Non-default kernel schedules are only supported on SM90 and beyond")
|
||||
|
||||
|
||||
def alignment_or_default(alignment_provided: int, default_alignment: int) -> int:
    """
    Selects the alignment to use, preferring an explicitly provided value.

    When `alignment_provided` is None the default is returned. Otherwise the provided
    value is returned after checking that it does not exceed `default_alignment`.

    :param alignment_provided: alignment preference specified. Can be None.
    :type alignment_provided: int
    :param default_alignment: alignment to use if `alignment_provided` is None; also the
                              upper bound for `alignment_provided`
    :type default_alignment: int

    :return: alignment to use
    :rtype: int

    :raises Exception: if `alignment_provided` is greater than `default_alignment`
    """
    if alignment_provided is None:
        return default_alignment

    if alignment_provided > default_alignment:
        raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.")

    return alignment_provided
|
||||
339
python/cutlass/utils/datatypes.py
Normal file
339
python/cutlass/utils/datatypes.py
Normal file
@@ -0,0 +1,339 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Utility functions for converting between frontend datatypes and CUTLASS datatypes
|
||||
"""
|
||||
|
||||
import cutlass_bindings
|
||||
|
||||
import cutlass
|
||||
from cutlass.backend.library import (
|
||||
DataTypeSize,
|
||||
MathInstruction,
|
||||
MathOperation,
|
||||
ShortLayoutTypeNames,
|
||||
TileDescription,
|
||||
)
|
||||
|
||||
# numpy is an optional dependency: record whether it imported and, if so, build the
# table mapping library data types to numpy scalar types.
try:
    import numpy as np
except ImportError:
    numpy_available = False
    _library_to_numpy_dict = {}
else:
    numpy_available = True
    _library_to_numpy_dict = {
        cutlass.DataType.f16: np.float16,
        cutlass.DataType.f32: np.float32,
        cutlass.DataType.f64: np.float64,
        cutlass.DataType.s8: np.int8,
        cutlass.DataType.s32: np.int32,
    }
|
||||
|
||||
|
||||
def numpy_library_type(inp) -> cutlass.DataType:
    """
    Converts a numpy data type to the matching library data type.

    :param inp: value compared (with ``==``) against the supported numpy scalar types

    :return: matching ``cutlass.DataType``, or None if numpy is unavailable or no type matches
    """
    if not numpy_available:
        return None

    # Matching uses `==` rather than a dict lookup, so any value that compares equal to
    # one of these numpy scalar types is accepted.
    candidates = (
        (np.float16, cutlass.DataType.f16),
        (np.float32, cutlass.DataType.f32),
        (np.float64, cutlass.DataType.f64),
        (np.int8, cutlass.DataType.s8),
        (np.int32, cutlass.DataType.s32),
    )
    for numpy_dtype, library_dtype in candidates:
        if inp == numpy_dtype:
            return library_dtype
    return None
|
||||
|
||||
|
||||
def numpy_type(inp):
    """
    Looks up the numpy type registered for library data type ``inp``.

    :param inp: library data type used as the key into ``_library_to_numpy_dict``

    :return: the mapped numpy type, or None if there is no entry (the table is empty when
             numpy could not be imported)
    """
    return _library_to_numpy_dict.get(inp)
|
||||
|
||||
|
||||
# cupy is an optional dependency: record whether it imported and, if so, build the
# table mapping library data types to cupy scalar types.
try:
    import cupy as cp
except ImportError:
    cupy_available = False
    _library_to_cupy_dict = {}
else:
    cupy_available = True
    _library_to_cupy_dict = {
        cutlass.DataType.f16: cp.float16,
        cutlass.DataType.f32: cp.float32,
        cutlass.DataType.f64: cp.float64,
        cutlass.DataType.s8: cp.int8,
        cutlass.DataType.s32: cp.int32,
    }
|
||||
|
||||
|
||||
def cupy_library_type(inp) -> cutlass.DataType:
    """
    Converts a cupy data type to the matching library data type.

    Covers the same types as ``_library_to_cupy_dict`` (f16, f32, f64, s8, s32) so that the
    forward and reverse conversions agree. The integer types were previously missing here.

    :param inp: value compared (with ``==``) against the supported cupy scalar types

    :return: matching ``cutlass.DataType``, or None if cupy is unavailable or no type matches
    """
    if not cupy_available:
        return None

    candidates = (
        (cp.float16, cutlass.DataType.f16),
        (cp.float32, cutlass.DataType.f32),
        (cp.float64, cutlass.DataType.f64),
        (cp.int8, cutlass.DataType.s8),
        (cp.int32, cutlass.DataType.s32),
    )
    for cupy_dtype, library_dtype in candidates:
        if inp == cupy_dtype:
            return library_dtype
    return None
|
||||
|
||||
|
||||
def cupy_type(inp):
    """
    Looks up the cupy type registered for library data type ``inp``.

    :param inp: library data type used as the key into ``_library_to_cupy_dict``

    :return: the mapped cupy type, or None if there is no entry (the table is empty when
             cupy could not be imported)
    """
    return _library_to_cupy_dict.get(inp)
|
||||
|
||||
|
||||
# torch is an optional dependency: record whether it imported and, if so, build the
# conversion tables between torch dtypes and library data types.
try:
    import torch
except ImportError:
    torch_available = False
    _torch_to_library_dict = {}
    _library_to_torch_dict = {}
else:
    torch_available = True

    # torch.half / torch.float / torch.double are aliases of torch.float16 / torch.float32 /
    # torch.float64, so listing only the sized names yields the same three-entry tables.
    _torch_to_library_dict = {
        torch.float16: cutlass.DataType.f16,
        torch.float32: cutlass.DataType.f32,
        torch.float64: cutlass.DataType.f64,
    }

    _library_to_torch_dict = {
        library_dtype: torch_dtype
        for torch_dtype, library_dtype in _torch_to_library_dict.items()
    }
|
||||
|
||||
|
||||
def torch_library_type(inp) -> cutlass.DataType:
    """
    Converts a torch dtype to the matching library data type.

    :param inp: torch dtype used as the key into ``_torch_to_library_dict``

    :return: matching ``cutlass.DataType``, or None if there is no entry (the table is empty
             when torch could not be imported)
    """
    return _torch_to_library_dict.get(inp)
|
||||
|
||||
|
||||
def torch_type(inp):
    """
    Looks up the torch dtype registered for library data type ``inp``.

    :param inp: library data type used as the key into ``_library_to_torch_dict``

    :return: the mapped torch dtype, or None if there is no entry (the table is empty when
             torch could not be imported)
    """
    return _library_to_torch_dict.get(inp)
|
||||
|
||||
|
||||
# bfloat16 is an optional dependency: record whether it imported.
try:
    import bfloat16
except ImportError:
    bfloat16_available = False
else:
    bfloat16_available = True


def bfloat16_library_type(inp) -> "cutlass.DataType":
    """
    Converts the ``bfloat16.bfloat16`` type to the library bf16 data type.

    :param inp: value compared (with ``==``) against ``bfloat16.bfloat16``

    :return: ``cutlass.DataType.bf16`` on a match; None if the bfloat16 package is
             unavailable or ``inp`` does not match
    """
    if bfloat16_available and inp == bfloat16.bfloat16:
        return cutlass.DataType.bf16
    return None


# The return annotation is a string on purpose: annotations are evaluated when the function
# is defined, so an unquoted ``bfloat16.bfloat16`` raises NameError at import time whenever
# the optional bfloat16 package is missing, defeating the guarded import above.
def bfloat16_type(inp) -> "bfloat16.bfloat16":
    """
    Converts the library bf16 data type to the ``bfloat16.bfloat16`` type.

    :param inp: value compared (with ``==``) against ``cutlass.DataType.bf16``

    :return: ``bfloat16.bfloat16`` on a match; None if the bfloat16 package is
             unavailable or ``inp`` does not match
    """
    if bfloat16_available and inp == cutlass.DataType.bf16:
        return bfloat16.bfloat16
    return None
|
||||
|
||||
|
||||
# Library data type -> Python-bound (cutlass_bindings) data type.
# Only the types listed here can be converted by `library_to_binding`.
library_to_binding_dict = {
    # Integer types
    cutlass.DataType.s8: cutlass_bindings.int8,
    cutlass.DataType.s32: cutlass_bindings.int32,
    # Floating-point types
    cutlass.DataType.f16: cutlass_bindings.float16,
    cutlass.DataType.bf16: cutlass_bindings.bfloat16,
    cutlass.DataType.f32: cutlass_bindings.float32,
    cutlass.DataType.f64: cutlass_bindings.float64,
    cutlass.DataType.tf32: cutlass_bindings.tfloat32,
}
|
||||
|
||||
# Python-bound (cutlass_bindings) data type -> library data type.
# This is the exact inverse of `library_to_binding_dict` (same pairs, same order), so it is
# derived from that table rather than spelled out a second time.
binding_to_library = {
    bound_dtype: library_dtype
    for library_dtype, bound_dtype in library_to_binding_dict.items()
}
|
||||
|
||||
|
||||
def binding_library_type(inp):
    """
    Converts a Python-bound CUTLASS data type to the matching library data type.

    :param inp: value used as the key into ``binding_to_library``

    :return: the mapped library data type, or None if ``inp`` has no entry
    """
    return binding_to_library.get(inp)
|
||||
|
||||
|
||||
def has_binding_type(inp: cutlass.DataType) -> bool:
    """
    Reports whether library data type ``inp`` has a Python-bound CUTLASS counterpart.

    :param inp: library data type to look up
    :type inp: cutlass.DataType

    :return: True if ``inp`` is a key of ``library_to_binding_dict``
    :rtype: bool
    """
    return inp in library_to_binding_dict
|
||||
|
||||
|
||||
def library_to_binding(inp: cutlass.DataType):
    """
    Converts a library data type to its Python-bound CUTLASS data type.

    :param inp: library data type to convert
    :type inp: cutlass.DataType

    :return: the entry of ``library_to_binding_dict`` for ``inp``

    :raises Exception: if ``inp`` has no entry in ``library_to_binding_dict``
    """
    if inp in library_to_binding_dict:
        return library_to_binding_dict[inp]
    raise Exception(f"No available conversion from library type {inp} to Python-bound CUTLASS type")
|
||||
|
||||
|
||||
def library_type(inp):
    """
    Converts a data type from any supported frontend into a library data type.

    Values that are already keys of ``cutlass.DataTypeSize`` are returned unchanged.
    Otherwise each frontend converter is tried in turn and the first non-None result wins.

    :param inp: data type to convert

    :return: the library data type corresponding to ``inp``

    :raises Exception: if no converter recognizes ``inp``
    """
    if inp in cutlass.DataTypeSize:
        return inp

    # Order matters only in that the first converter returning a non-None value is used.
    converters = (
        bfloat16_library_type,
        cupy_library_type,
        numpy_library_type,
        torch_library_type,
        binding_library_type,
    )
    for convert in converters:
        converted = convert(inp)
        if converted is not None:
            return converted

    raise Exception(f"No available conversion from type {inp} to a library type.")
|
||||
|
||||
|
||||
def library_layout(layout):
    """
    Converts a layout into a library (profiler) layout.

    Values that are already keys of ``cutlass.LayoutTag`` are returned unchanged.

    :param layout: library layout or Python-bound CUTLASS layout

    :return: the corresponding ``cutlass.LayoutType`` value

    :raises Exception: if ``layout`` is neither a library layout nor the Python-bound
                       RowMajor / ColumnMajor layout
    """
    if layout in cutlass.LayoutTag:
        return layout

    # Convert Python-bound CUTLASS layout to profiler library layout
    if layout == cutlass_bindings.RowMajor:
        return cutlass.LayoutType.RowMajor
    if layout == cutlass_bindings.ColumnMajor:
        return cutlass.LayoutType.ColumnMajor

    raise Exception(f"No conversion available for layout {layout} to library layout.")
|
||||
|
||||
|
||||
def binding_type(inp):
    """
    Converts a data type from any supported frontend into a Python-bound CUTLASS data type.

    Values that are already keys of the backend ``DataTypeSize`` table are returned
    unchanged. Everything else is first converted to a library type and then to its
    Python-bound counterpart.

    :param inp: data type to convert

    :return: the Python-bound CUTLASS data type corresponding to ``inp``

    :raises Exception: propagated from ``library_type`` / ``library_to_binding`` when no
                       conversion exists
    """
    if inp in DataTypeSize:
        return inp
    return library_to_binding(library_type(inp))
|
||||
|
||||
|
||||
def binding_layout(layout):
    """
    Converts a layout into a Python-bound CUTLASS layout.

    Values that are already keys of ``ShortLayoutTypeNames`` are returned unchanged.

    :param layout: Python-bound CUTLASS layout or library layout

    :return: the corresponding ``cutlass_bindings`` layout

    :raises Exception: if ``layout`` is neither a key of ``ShortLayoutTypeNames`` nor the
                       library RowMajor / ColumnMajor layout
    """
    if layout in ShortLayoutTypeNames:
        return layout

    if layout == cutlass.LayoutType.RowMajor:
        return cutlass_bindings.RowMajor
    if layout == cutlass.LayoutType.ColumnMajor:
        return cutlass_bindings.ColumnMajor

    raise Exception(f"No conversion available for layout {layout} to Python-bound CUTLASS layout.")
|
||||
|
||||
|
||||
def _tensor_from_numpy(np_tensor):
    """
    Determines the library data type and layout of a numpy-style array.

    The layout comes from the array's contiguity flags. C-contiguity is checked first, so
    an array that is both C- and F-contiguous is reported as row-major.

    :param np_tensor: array exposing ``dtype`` and ``flags.c_contiguous`` / ``flags.f_contiguous``

    :return: tuple of (library data type, library layout)
    :rtype: tuple

    :raises Exception: if the array is neither C- nor F-contiguous, or if its dtype has no
                       library equivalent (propagated from ``library_type``)
    """
    dtype = library_type(np_tensor.dtype)

    if np_tensor.flags.c_contiguous:
        layout = cutlass.LayoutType.RowMajor
    elif np_tensor.flags.f_contiguous:
        layout = cutlass.LayoutType.ColumnMajor
    else:
        # Previously this fell through with `layout` unbound and failed with an opaque
        # UnboundLocalError; report the actual problem instead.
        raise Exception("Unable to determine layout of tensor that is neither C-contiguous nor F-contiguous.")

    return (dtype, layout)
|
||||
|
||||
|
||||
def _tensor_from_torch(pt_tensor):
    """
    Determines the library data type and layout of a torch tensor.

    The layout is always reported as row-major; the tensor's strides are not inspected.

    :param pt_tensor: tensor exposing ``dtype``

    :return: tuple of (library data type, ``cutlass.LayoutType.RowMajor``)
    :rtype: tuple
    """
    return (library_type(pt_tensor.dtype), cutlass.LayoutType.RowMajor)
|
||||
|
||||
|
||||
def get_datatype_and_layout(tensor):
    """
    Determines the library data type and layout of a tensor from a supported frontend.

    numpy and cupy arrays share the same handling; torch tensors are handled separately.
    A frontend is only considered if its package was importable.

    :param tensor: numpy ``ndarray``, cupy ``ndarray``, or ``torch.Tensor``

    :return: tuple of (library data type, library layout)
    :rtype: tuple

    :raises Exception: if ``tensor`` is not an instance of any available frontend's tensor type
    """
    if numpy_available and isinstance(tensor, np.ndarray):
        return _tensor_from_numpy(tensor)

    # cupy arrays expose the same dtype/flags attributes, so the numpy helper is reused.
    if cupy_available and isinstance(tensor, cp.ndarray):
        return _tensor_from_numpy(tensor)

    if torch_available and isinstance(tensor, torch.Tensor):
        return _tensor_from_torch(tensor)

    raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
|
||||
|
||||
|
||||
def binding_opclass(opclass: cutlass.OpcodeClass):
    """
    Converts a library opcode class to the Python-bound CUTLASS opcode class.

    :param opclass: library opcode class; only TensorOp and Simt are supported
    :type opclass: cutlass.OpcodeClass

    :return: the corresponding ``cutlass_bindings.OpClass`` value

    :raises Exception: if ``opclass`` is neither TensorOp nor Simt
    """
    if opclass == cutlass.OpcodeClass.TensorOp:
        return cutlass_bindings.OpClass.TensorOp
    if opclass == cutlass.OpcodeClass.Simt:
        return cutlass_bindings.OpClass.Simt

    raise Exception(f"Unable to convert opcode class of type {opclass} to Python-bound CUTLASS opcode class.")
|
||||
|
||||
|
||||
# Index of backend MathOperation members keyed by their `.value`; used by
# `backend_math_operation` to translate by value.
_math_operation_value_map = {
    backend_op.value: backend_op for backend_op in MathOperation
}
|
||||
|
||||
|
||||
def backend_math_operation(math_op: cutlass.MathOperation):
    """
    Converts a library math operation to the backend math operation with the same value.

    :param math_op: library math operation; matched to a backend member via ``math_op.value``
    :type math_op: cutlass.MathOperation

    :return: the backend ``MathOperation`` member whose ``value`` equals ``math_op.value``

    :raises Exception: if no backend member has that value
    """
    op_value = math_op.value
    if op_value in _math_operation_value_map:
        return _math_operation_value_map[op_value]
    raise Exception(f"Unable to convert math operation of type {math_op} to backend math operation.")
|
||||
|
||||
|
||||
def construct_backend_td(td: cutlass.TileDescription,
                         kernel_schedule: cutlass.KernelScheduleType) -> TileDescription:
    """
    Builds a backend TileDescription from a library TileDescription.

    The math instruction's element types, opcode class, and math operation are converted
    to their backend / Python-bound forms; all other fields are passed through unchanged.

    :param td: library tile description to convert
    :type td: cutlass.TileDescription
    :param kernel_schedule: kernel schedule forwarded to the backend tile description
                            (callers in this module may pass None)
    :type kernel_schedule: cutlass.KernelScheduleType

    :return: backend TileDescription
    :rtype: cutlass.backend.TileDescription
    """
    src_mi = td.math_instruction

    # Convert each math-instruction field in the order the backend constructor expects.
    element_a = binding_type(src_mi.element_a)
    element_b = binding_type(src_mi.element_b)
    element_accumulator = binding_type(src_mi.element_accumulator)
    opclass = binding_opclass(src_mi.opcode_class)
    math_operation = backend_math_operation(src_mi.math_operation)

    backend_mi = MathInstruction(
        src_mi.instruction_shape,
        element_a,
        element_b,
        element_accumulator,
        opclass,
        math_operation,
    )
    return TileDescription(
        td.threadblock_shape,
        td.stages,
        td.warp_count,
        backend_mi,
        td.cluster_shape,
        kernel_schedule,
    )
|
||||
|
||||
|
||||
def td_from_profiler_op(op) -> TileDescription:
    """
    Converts the profiler's TileDescription in ``op`` into the backend TileDescription

    The kernel schedule is taken from ``op.kernel_schedule`` when that attribute exists,
    and is None otherwise.

    :param op: profiler Operation

    :returns: backend TileDescription
    :rtype: cutlass.backend.TileDescription
    """
    kernel_schedule = getattr(op, 'kernel_schedule', None)
    return construct_backend_td(op.tile_description, kernel_schedule)
|
||||
|
||||
|
||||
# The parameter annotation is the profiler (library) TileDescription, matching both the
# docstring and `construct_backend_td`; it was previously annotated as the backend type.
def td_from_profiler_td(td: cutlass.TileDescription) -> TileDescription:
    """
    Converts the profiler's TileDescription into the backend TileDescription

    No kernel schedule is attached: the backend tile description is constructed with
    ``kernel_schedule=None``.

    :param td: profiler TileDescription
    :type td: cutlass.TileDescription

    :returns: backend TileDescription
    :rtype: cutlass.backend.TileDescription
    """
    return construct_backend_td(td, kernel_schedule=None)
|
||||
Reference in New Issue
Block a user