mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-12 17:25:45 +00:00
CUTLASS 3.3.0 (#1167)
* Release 3.3.0 Adds support for mixed precision GEMMs On Hopper and Ampere Adds support for < 16B aligned GEMMs on Hopper Enhancements to EVT Enhancements to Python interface Enhancements to Sub-byte type handling in CuTe Several other bug-fixes and performance improvements. * minor doc update
This commit is contained in:
@@ -36,26 +36,27 @@ Utility functions for checking constraints on kernels and calculating kernel att
|
||||
|
||||
import ctypes
|
||||
|
||||
from cutlass_library import DataTypeSize, OperationKind, SharedMemPerCC
|
||||
|
||||
import cutlass
|
||||
from cutlass import DataTypeSize
|
||||
from cutlass.backend.library import TileDescription
|
||||
|
||||
|
||||
def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: cutlass.OperationKind) -> int:
|
||||
def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: OperationKind) -> int:
|
||||
"""
|
||||
Returns the amount of shared memory in bytes consumed in a single stage of a kernel.
|
||||
|
||||
:param td: tile description to compute shared memory of
|
||||
:type td: TileDescription
|
||||
:param operation_kind: identifier for the type of operation being performed
|
||||
:type operation_kind: cutlass.OperationKind
|
||||
:type operation_kind: cutlass_library.OperationKind
|
||||
|
||||
:return: number of bytes of shared memory consumed by a single stage
|
||||
:rtype: int
|
||||
"""
|
||||
m, n, k = td.threadblock_shape
|
||||
|
||||
if operation_kind == cutlass.OperationKind.Gemm:
|
||||
if operation_kind == OperationKind.Gemm:
|
||||
stage_barrier_bytes = 32
|
||||
return (
|
||||
(DataTypeSize[td.math_instruction.element_a] * m * k // 8)
|
||||
@@ -82,7 +83,8 @@ def valid_stage_count(
|
||||
kernel_cc: int,
|
||||
td: TileDescription,
|
||||
element_C: cutlass.DataType = None,
|
||||
element_D: cutlass.DataType = None) -> tuple:
|
||||
element_D: cutlass.DataType = None,
|
||||
verbose: bool = True) -> tuple:
|
||||
"""
|
||||
Checks whether a device with `cc` supports the number of stages within `tile_description`, both
|
||||
based on raw limits on the number of stages and based on shared memory capacity
|
||||
@@ -97,6 +99,8 @@ def valid_stage_count(
|
||||
:type element_C: cutlass.DataType
|
||||
:param element_D: data type of operand D
|
||||
:type element_D: cutlass.DataType
|
||||
:param verbose: whether to log warnings
|
||||
:type verbose: bool
|
||||
|
||||
:return: tuple with the first element indicating whether the provided tile description is
|
||||
valid for the provided device and the second element being an error message
|
||||
@@ -107,7 +111,7 @@ def valid_stage_count(
|
||||
# Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically
|
||||
# determines the stage count to use. Thus, all settings are valid in these scenarios.
|
||||
return (True, "")
|
||||
else:
|
||||
elif verbose:
|
||||
cutlass.logger.warning(
|
||||
"Setting an explicit stage count for SM90 kernels currently may "
|
||||
"result in compilation errors if the combination of tile shape, "
|
||||
@@ -125,9 +129,9 @@ def valid_stage_count(
|
||||
# only catches cases in which the mainloop exceeds the device's shared memory capacity.
|
||||
# This is not a concern for CUTLASS 2.x kernels, for which the shared memory of the
|
||||
# mainloop and epilogue is shared.
|
||||
smem_per_stage = calculate_smem_usage_per_stage(td, cutlass.OperationKind.Gemm)
|
||||
smem_per_stage = calculate_smem_usage_per_stage(td, OperationKind.Gemm)
|
||||
smem_usage_mainloop = (smem_per_stage * td.stages)
|
||||
smem_arch = cutlass.SharedMemPerCC[cc] << 10
|
||||
smem_arch = SharedMemPerCC[cc] << 10
|
||||
if smem_usage_mainloop > smem_arch:
|
||||
return ( False,
|
||||
"Configuration uses too much shared memory. Consider reducing stage count or tile shape.\n"
|
||||
@@ -214,7 +218,9 @@ def valid_schedule(
|
||||
return (False, "Kernel and epilogue schedules must either both be auto or neither be auto")
|
||||
|
||||
if not tile_scheduler_default:
|
||||
if (tile_scheduler == cutlass.TileSchedulerType.StreamK) and (kernel_schedule != cutlass.KernelScheduleType.TmaWarpSpecializedCooperative):
|
||||
cooperative_kernels = [cutlass.KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
cutlass.KernelScheduleType.CpAsyncWarpSpecializedCooperative]
|
||||
if (tile_scheduler == cutlass.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels):
|
||||
return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule")
|
||||
return (True, "")
|
||||
|
||||
|
||||
@@ -35,33 +35,55 @@ Utility functions for converting between frontend datatypes and CUTLASS datatype
|
||||
"""
|
||||
|
||||
import cutlass
|
||||
from cutlass import (
|
||||
from cutlass_library import (
|
||||
DataTypeSize,
|
||||
MathOperation,
|
||||
MathInstruction
|
||||
)
|
||||
from cutlass.backend.library import (
|
||||
MathInstruction,
|
||||
MathOperation,
|
||||
TileDescription,
|
||||
)
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
bfloat16_available = None
|
||||
cupy_available = None
|
||||
numpy_available = None
|
||||
torch_available = None
|
||||
_library_to_cupy_dict = None
|
||||
_library_to_numpy_dict = None
|
||||
_library_to_torch_dict = None
|
||||
_torch_to_library_dict = None
|
||||
|
||||
numpy_available = True
|
||||
_library_to_numpy_dict = {
|
||||
cutlass.DataType.f16: np.float16,
|
||||
cutlass.DataType.f32: np.float32,
|
||||
cutlass.DataType.f64: np.float64,
|
||||
cutlass.DataType.s8: np.int8,
|
||||
cutlass.DataType.s32: np.int32,
|
||||
}
|
||||
except ImportError:
|
||||
numpy_available = False
|
||||
_library_to_numpy_dict = {}
|
||||
|
||||
def is_numpy_available():
    """
    Lazily determine whether NumPy is importable, populating the
    CUTLASS-to-NumPy datatype mapping on the first successful import.

    :return: whether the `numpy` package is available
    :rtype: bool
    """
    global numpy_available, _library_to_numpy_dict
    # Already probed on an earlier call — reuse the cached answer.
    if numpy_available is not None:
        return numpy_available

    try:
        import numpy as np
    except ImportError:
        numpy_available = False
        _library_to_numpy_dict = {}
        return numpy_available

    numpy_available = True
    _library_to_numpy_dict = {
        cutlass.DataType.f16: np.float16,
        cutlass.DataType.f32: np.float32,
        cutlass.DataType.f64: np.float64,
        cutlass.DataType.s8: np.int8,
        cutlass.DataType.s32: np.int32,
    }
    return numpy_available
|
||||
|
||||
|
||||
def is_numpy_tensor(inp) -> bool:
    """
    Return whether `inp` is a NumPy ndarray (False when NumPy is unavailable).

    :param inp: object to test
    :return: True if `inp` is an `np.ndarray`
    :rtype: bool
    """
    if not is_numpy_available():
        return False
    import numpy as np
    return isinstance(inp, np.ndarray)
|
||||
|
||||
|
||||
def numpy_library_type(inp) -> cutlass.DataType:
|
||||
if numpy_available:
|
||||
if is_numpy_available():
|
||||
import numpy as np
|
||||
if inp == np.float16:
|
||||
return cutlass.DataType.f16
|
||||
elif inp == np.float32:
|
||||
@@ -79,24 +101,36 @@ def numpy_type(inp):
|
||||
return _library_to_numpy_dict.get(inp, None)
|
||||
|
||||
|
||||
try:
|
||||
import cupy as cp
|
||||
def is_cupy_available():
    """
    Lazily determine whether CuPy is importable, populating the
    CUTLASS-to-CuPy datatype mapping on the first successful import.

    :return: whether the `cupy` package is available
    :rtype: bool
    """
    # `_library_to_cupy_dict` must be declared global alongside `cupy_available`
    # (as the numpy/torch counterparts do); otherwise the assignment below binds
    # a function local and the module-level mapping used by `cupy_type` stays None.
    global cupy_available, _library_to_cupy_dict
    if cupy_available is None:
        try:
            import cupy as cp

            cupy_available = True
            _library_to_cupy_dict = {
                cutlass.DataType.f16: cp.float16,
                cutlass.DataType.f32: cp.float32,
                cutlass.DataType.f64: cp.float64,
                cutlass.DataType.s8: cp.int8,
                cutlass.DataType.s32: cp.int32,
            }
        except ImportError:
            cupy_available = False
            _library_to_cupy_dict = {}
    return cupy_available
|
||||
|
||||
|
||||
def is_cupy_tensor(inp) -> bool:
    """
    Return whether `inp` is a CuPy ndarray (False when CuPy is unavailable).

    :param inp: object to test
    :return: True if `inp` is a `cp.ndarray`
    :rtype: bool
    """
    if not is_cupy_available():
        return False
    import cupy as cp
    return isinstance(inp, cp.ndarray)
|
||||
|
||||
|
||||
def cupy_library_type(inp) -> cutlass.DataType:
|
||||
if cupy_available:
|
||||
if is_cupy_available():
|
||||
import cupy as cp
|
||||
if inp == cp.float16:
|
||||
return cutlass.DataType.f16
|
||||
elif inp == cp.float32:
|
||||
@@ -110,39 +144,50 @@ def cupy_type(inp):
|
||||
return _library_to_cupy_dict.get(inp, None)
|
||||
|
||||
|
||||
try:
|
||||
import torch
|
||||
def is_torch_available():
    """
    Lazily determine whether PyTorch is importable, populating the
    torch-to-CUTLASS and CUTLASS-to-torch datatype mappings on the first
    successful import.

    :return: whether the `torch` package is available
    :rtype: bool
    """
    global torch_available, _library_to_torch_dict, _torch_to_library_dict
    if torch_available is None:
        try:
            import torch
        except ImportError:
            torch_available = False
            _torch_to_library_dict = {}
            _library_to_torch_dict = {}
        else:
            torch_available = True
            # torch.half/torch.float/torch.double are aliases of
            # torch.float16/torch.float32/torch.float64, so a single entry per
            # dtype covers both spellings.
            _torch_to_library_dict = {
                torch.float16: cutlass.DataType.f16,
                torch.bfloat16: cutlass.DataType.bf16,
                torch.float32: cutlass.DataType.f32,
                torch.float64: cutlass.DataType.f64,
                torch.int8: cutlass.DataType.s8,
                torch.int32: cutlass.DataType.s32,
                torch.uint8: cutlass.DataType.u8,
            }

            _library_to_torch_dict = {
                cutlass.DataType.f16: torch.float16,
                cutlass.DataType.bf16: torch.bfloat16,
                cutlass.DataType.f32: torch.float32,
                cutlass.DataType.f64: torch.float64,
                cutlass.DataType.s8: torch.int8,
                cutlass.DataType.s32: torch.int32,
                cutlass.DataType.u8: torch.uint8,
            }
    return torch_available
|
||||
|
||||
|
||||
def is_torch_tensor(inp) -> bool:
    """
    Return whether `inp` is a PyTorch tensor (False when PyTorch is unavailable).

    :param inp: object to test
    :return: True if `inp` is a `torch.Tensor`
    :rtype: bool
    """
    if not is_torch_available():
        return False
    import torch
    return isinstance(inp, torch.Tensor)
|
||||
|
||||
|
||||
def torch_library_type(inp) -> cutlass.DataType:
|
||||
@@ -153,28 +198,35 @@ def torch_type(inp):
|
||||
return _library_to_torch_dict.get(inp, None)
|
||||
|
||||
|
||||
try:
|
||||
import bfloat16
|
||||
def is_bfloat16_available():
    """
    Lazily determine whether the `bfloat16` package is importable.

    :return: whether the `bfloat16` package is available
    :rtype: bool
    """
    global bfloat16_available
    if bfloat16_available is None:
        try:
            import bfloat16  # noqa: F401  (import probes availability only)
        except ImportError:
            bfloat16_available = False
        else:
            bfloat16_available = True
    return bfloat16_available
|
||||
|
||||
|
||||
def bfloat16_library_type(inp) -> cutlass.DataType:
    """
    Convert the `bfloat16` package's dtype to the CUTLASS bf16 datatype.

    :param inp: dtype to convert
    :return: `cutlass.DataType.bf16` on a match, otherwise None
    """
    if not is_bfloat16_available():
        return None
    import bfloat16
    return cutlass.DataType.bf16 if inp == bfloat16.bfloat16 else None
|
||||
|
||||
|
||||
def bfloat16_type(inp):
    """
    Convert the CUTLASS bf16 datatype to the `bfloat16` package's dtype.

    :param inp: CUTLASS datatype to convert
    :return: `bfloat16.bfloat16` when `inp` is `cutlass.DataType.bf16`, otherwise None
    """
    if not is_bfloat16_available():
        return None
    import bfloat16
    return bfloat16.bfloat16 if inp == cutlass.DataType.bf16 else None
|
||||
|
||||
|
||||
def library_type(inp):
|
||||
if inp in cutlass.DataTypeSize.keys():
|
||||
if inp in DataTypeSize:
|
||||
return inp
|
||||
|
||||
for cvt_fn in [
|
||||
@@ -205,23 +257,20 @@ def _tensor_from_torch(pt_tensor):
|
||||
|
||||
|
||||
def get_datatype_and_layout(tensor):
    """
    Infer the CUTLASS datatype and layout of `tensor`.

    :param tensor: NumPy/CuPy ndarray, torch Tensor, or Python scalar
    :return: (datatype, layout) pair describing `tensor`
    :raises Exception: when `tensor` is of an unsupported type
    """
    # NumPy and CuPy arrays share the same extraction path.
    if is_numpy_tensor(tensor) or is_cupy_tensor(tensor):
        return _tensor_from_numpy(tensor)
    if is_torch_tensor(tensor):
        return _tensor_from_torch(tensor)
    if isinstance(tensor, (float, int)):
        # Scalars are treated as single-precision row-major values.
        return (cutlass.DataType.f32, cutlass.LayoutType.RowMajor)
    raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
|
||||
|
||||
|
||||
def get_tensor_shape(tensor, op="GEMM"):
|
||||
if (numpy_available and isinstance(tensor, np.ndarray)) or (
|
||||
cupy_available and isinstance(tensor, cp.ndarray)
|
||||
):
|
||||
if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)):
|
||||
return tensor.shape
|
||||
elif torch_available and isinstance(tensor, torch.Tensor):
|
||||
elif is_torch_tensor(tensor):
|
||||
size = tensor.size()
|
||||
if op == "CONV":
|
||||
# PyTorch Tensors have shape NCHW
|
||||
@@ -237,7 +286,7 @@ def get_tensor_shape(tensor, op="GEMM"):
|
||||
_math_operation_value_map = {x.value: x for x in MathOperation}
|
||||
|
||||
|
||||
def backend_math_operation(math_op: MathOperation):
    """
    Convert a `cutlass_library` math operation to its backend equivalent,
    matching by enum value.

    :param math_op: math operation to convert
    :return: backend math operation with the same value
    :raises Exception: when no backend operation shares `math_op`'s value
    """
    if math_op.value not in _math_operation_value_map:
        raise Exception(f"Unable to convert math operation of type {math_op} to backend math operation.")
    return _math_operation_value_map[math_op.value]
|
||||
|
||||
185
python/cutlass/utils/profiler.py
Normal file
185
python/cutlass/utils/profiler.py
Normal file
@@ -0,0 +1,185 @@
|
||||
#################################################################################################
|
||||
#
|
||||
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
#
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Profiler based on the cuda events
|
||||
"""
|
||||
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
from cuda import cuda, cudart
|
||||
import numpy as np
|
||||
|
||||
from cutlass import CUTLASS_PATH
|
||||
from cutlass.backend.library import DataTypeSize
|
||||
from cutlass.op.op import OperationBase
|
||||
from cutlass.shape import GemmCoord
|
||||
from cutlass.utils.datatypes import is_numpy_tensor
|
||||
|
||||
|
||||
class GpuTimer:
    """
    Timer measuring elapsed GPU time between two recorded CUDA events.
    """

    def __init__(self) -> None:
        # events[0] brackets the start of the timed region, events[1] the end.
        # cuEventCreate returns (err, event); [1] keeps only the event handle.
        self.events = [
            cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
            cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
        ]

    def start(self, stream=cuda.CUstream(0)):
        """
        Record the start event.

        :param stream: stream on which to record (default stream if omitted)
        :raises RuntimeError: on a CUDA driver error
        """
        (err,) = cuda.cuEventRecord(self.events[0], stream)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error {str(err)}")

    def stop(self, stream=cuda.CUstream(0)):
        """
        Record the stop event.

        :param stream: stream on which to record (default stream if omitted)
        :raises RuntimeError: on a CUDA driver error
        """
        # (Dead trailing `pass` removed from the original implementation.)
        (err,) = cuda.cuEventRecord(self.events[1], stream)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error {str(err)}")

    def stop_and_wait(self, stream=cuda.CUstream(0)):
        """
        Record the stop event and block until preceding work completes.
        Synchronizes `stream` when one is provided, otherwise the device.

        :param stream: stream on which to record and synchronize
        :raises RuntimeError: on a CUDA error
        """
        self.stop(stream)
        if stream:
            (err,) = cuda.cuStreamSynchronize(stream)
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"CUDA Error {str(err)}")
        else:
            # NOTE(review): cudaDeviceSynchronize returns a cudart error code,
            # compared here against the driver-API CUDA_SUCCESS; this relies on
            # both success codes being 0 — confirm intended.
            (err,) = cudart.cudaDeviceSynchronize()
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"CUDA Error {str(err)}")

    def duration(self, iterations=1):
        """
        Elapsed time between the start and stop events, averaged over
        `iterations`.

        :param iterations: number of iterations the measured span covered
        :type iterations: int
        :return: average elapsed time in milliseconds per iteration
        :rtype: float
        :raises RuntimeError: on a CUDA driver error
        """
        err, duration = cuda.cuEventElapsedTime(self.events[0], self.events[1])
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error {str(err)}")
        return duration / float(iterations)
|
||||
|
||||
|
||||
class CUDAEventProfiler:
    """
    Profiler that times a compiled CUTLASS operation with CUDA events and can
    cross-check results against the standalone `cutlass_profiler` binary.

    :param op: operation to profile; `op.run(*args, **kwargs)` is invoked once
        to construct the arguments object reused on every timed launch
    :type op: OperationBase
    :param warmup_iterations: number of untimed warmup launches
    :type warmup_iterations: int
    :param iterations: number of timed launches
    :type iterations: int
    """

    def __init__(self, op: OperationBase, warmup_iterations: int=500, iterations: int=500, *args, **kwargs) -> None:
        self.arguments = op.run(*args, **kwargs)
        self.operation = op.operation
        self.warmup_iterations = warmup_iterations
        self.iterations = iterations
        self.timer = GpuTimer()

    #
    # Cutlass Python Interface Profiler
    #

    def __call__(self):
        """
        Time the kernel through the Python interface.

        :return: average runtime per iteration, in milliseconds
        :rtype: float
        """
        for _ in range(self.warmup_iterations):
            self.operation.run(self.arguments)

        self.timer.start()
        for _ in range(self.iterations):
            self.operation.run(self.arguments)

        self.timer.stop_and_wait()
        return self.timer.duration(self.iterations)

    #
    # CUTLASS Profiler
    #

    def run_cutlass_profiler(self):
        """
        Profile the same kernel through the `cutlass_profiler` binary.

        Sanity-checks the profiler's reported byte and FLOP counts against the
        local `bytes`/`flops` estimates before returning the runtime.

        :return: runtime reported by the CUTLASS profiler, in milliseconds
        :rtype: float
        """
        alpha = 1.0
        beta = 1.0

        profiler_path = CUTLASS_PATH + "/build/tools/profiler/cutlass_profiler"
        kernel_name = self.operation.procedural_name()
        verification_providers = "device"
        provider = "cutlass"
        problem_size = self.arguments.problem_size

        if "cutlass3x" in kernel_name:
            # cutlass3x generator only have column-major output
            layout_name = self.operation.layout_name_3x()
            if layout_name[-1] == "t":
                # NOTE(review): the original comprehension filtered on
                # `l == "t" or "t"`, which is always true — every character maps
                # to "n", written explicitly below. Confirm whether only "t"
                # characters were meant to be flipped.
                new_layout_name = "n" * len(layout_name)
                # NOTE(review): m/n/k are read as attributes here but invoked as
                # methods below — confirm the type of `problem_size`.
                problem_size = GemmCoord(problem_size.n, problem_size.m, problem_size.k)
                kernel_name = kernel_name.replace(layout_name, new_layout_name)

        batch_count = self.arguments.batch_count

        cmd = f"{profiler_path} --kernels={kernel_name} --verification-providers={verification_providers} " \
              f"--providers={provider} --m={problem_size.m()} --n={problem_size.n()} --k={problem_size.k()} " \
              f"--batch_count={batch_count} --alpha={alpha} --beta={beta} "\
              f"--warmup-iterations={self.warmup_iterations} --profiling-iterations={self.iterations}"

        result = subprocess.getoutput(cmd)

        # Escaped decimal point (the original `\d+.\d+` let the dot match any
        # character).
        m = re.search(r"Runtime:\s+(?P<runtime>\d+\.\d+)", result)
        runtime = float(m.group("runtime"))

        m = re.search(r"Bytes:\s+(?P<bytes>\d+)", result)
        byte_count = int(m.group("bytes"))

        m = re.search(r"FLOPs:\s+(?P<flops>\d+)", result)
        flops = int(m.group("flops"))

        # check if the problem size matches
        assert byte_count == self.bytes(problem_size, batch_count, beta)
        assert flops == self.flops(problem_size, batch_count, beta)

        return runtime

    def bytes(self, problem_size, batch_count=1, beta=0.0):
        """
        Estimate global-memory traffic in bytes for the GEMM.

        :param problem_size: GEMM problem size with `m()`, `n()`, `k()` accessors
        :param batch_count: number of batched GEMMs
        :type batch_count: int
        :param beta: epilogue scalar; non-zero adds a read of the C tensor
        :type beta: float
        :return: estimated bytes moved
        :rtype: int
        """
        m = problem_size.m()
        n = problem_size.n()
        k = problem_size.k()

        # Reads of A (m x k) and B (n x k), plus the write of the (m x n) output.
        # (Renamed from `bytes`, which shadowed the builtin.)
        byte_count = (
            (DataTypeSize[self.operation.A.element] * m // 8) * k
            + (DataTypeSize[self.operation.B.element] * n // 8) * k
            + (DataTypeSize[self.operation.C.element] * m // 8) * n
        )

        if beta != 0:
            # Non-zero beta also reads the source C tensor.
            byte_count += (DataTypeSize[self.operation.C.element] * m // 8) * n

        byte_count *= batch_count

        return byte_count

    def flops(self, problem_size, batch_count=1, beta=0.0):
        """
        Estimate floating-point operations for the GEMM: 2*m*n*k per batch,
        plus 2*m*n per batch for the beta * C update when beta is non-zero.

        :param problem_size: GEMM problem size with `m()`, `n()`, `k()` accessors
        :param batch_count: number of batched GEMMs
        :type batch_count: int
        :param beta: epilogue scalar
        :type beta: float
        :return: estimated FLOP count
        :rtype: int
        """
        m = problem_size.m()
        n = problem_size.n()
        k = problem_size.k()

        flops_ = (m * n * k) * 2 * batch_count

        if beta != 0:
            flops_ += m * n * batch_count * 2

        return flops_
|
||||
|
||||
Reference in New Issue
Block a user