mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-13 01:35:45 +00:00
v4.4 tag release update. (#3032)
This commit is contained in:
@@ -54,6 +54,7 @@ class MatmulHeuristics:
|
||||
|
||||
def __init__(self, gpu = None):
|
||||
import nvMatmulHeuristics
|
||||
import inspect
|
||||
self.mmh_lib = nvMatmulHeuristics
|
||||
self.gpu = gpu
|
||||
|
||||
@@ -62,13 +63,63 @@ class MatmulHeuristics:
|
||||
else:
|
||||
nvmmhInterfaceEx = self.mmh_lib.NvMatmulHeuristicsInterfaceEx
|
||||
|
||||
self.lh = nvmmhInterfaceEx(
|
||||
# nvidia-matmul-heuristics 0.1.0.28 changed the API:
|
||||
# - Constructor: removed 'load_discovery_implicitly' and 'gpu' params
|
||||
# - GPU: now set via createHardwareDescriptor() + setHardwarePredefinedGpu()
|
||||
# - setBackendValueProperty renamed to setBackendPropertyValue (simpler signature)
|
||||
# - getEx: added hardware_descriptor parameter
|
||||
init_params = set(inspect.signature(self.mmh_lib.NvMatmulHeuristicsInterfaceEx.__init__).parameters.keys())
|
||||
self._legacy_api = 'load_discovery_implicitly' in init_params
|
||||
|
||||
init_kwargs = dict(
|
||||
backend=self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"],
|
||||
flags=self.mmh_lib.NvMatmulHeuristicsFlags.PERF_MODEL_BASED_AUTO_TUNING,
|
||||
load_discovery_implicitly=True,
|
||||
gpu=self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None
|
||||
)
|
||||
|
||||
if self._legacy_api:
|
||||
# <= 0.1.0.27
|
||||
init_kwargs['gpu'] = self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None
|
||||
init_kwargs['load_discovery_implicitly'] = True
|
||||
|
||||
self.lh = nvmmhInterfaceEx(**init_kwargs)
|
||||
|
||||
# >= 0.1.0.28: gpu is set via hardware descriptor after construction,
|
||||
# and passed to getEx() calls
|
||||
self.hw_desc = None
|
||||
if not self._legacy_api and self.gpu:
|
||||
self.hw_desc = self.lh.createHardwareDescriptor()
|
||||
if self.hw_desc is None:
|
||||
raise RuntimeError("Failed to create hardware descriptor for GPU: " + self.gpu)
|
||||
self.lh.setHardwarePredefinedGpu(self.hw_desc, self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu])
|
||||
|
||||
self.backend = self.lh.createBackend(self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"])
|
||||
|
||||
if not self._legacy_api:
|
||||
lh = self.lh
|
||||
original_del = type(lh).__del__
|
||||
|
||||
def _safe_del(self_lh):
|
||||
try:
|
||||
original_del(self_lh)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
type(lh).__del__ = _safe_del
|
||||
|
||||
def __del__(self):
|
||||
"""Clean up resources in correct order before the library's __del__ runs."""
|
||||
try:
|
||||
if hasattr(self, 'backend') and self.backend:
|
||||
self.lh.destroyBackend(self.backend)
|
||||
self.backend = None
|
||||
if hasattr(self, 'hw_desc') and self.hw_desc:
|
||||
self.lh.destroyHardwareDescriptor(self.hw_desc)
|
||||
self.hw_desc = None
|
||||
# Null out the handle so the library's __del__ skips nvMatmulHeuristicsDestroy
|
||||
if hasattr(self, 'lh') and self.lh and hasattr(self.lh, 'handle'):
|
||||
self.lh.handle = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _layout_from_cutlass(self, layouts):
|
||||
assert(len(layouts)==3)
|
||||
@@ -98,41 +149,45 @@ class MatmulHeuristics:
|
||||
else:
|
||||
return a_c + dtype_to_cublas[dtype_b] + dtype_to_cublas[dtype_c] + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d]
|
||||
|
||||
def _set_backend_property(self, property, value):
|
||||
"""Compat wrapper: setBackendValueProperty (<=0.1.0.27) vs setBackendPropertyValue (>=0.1.0.28)"""
|
||||
if self._legacy_api:
|
||||
c_val = ctypes.c_int(value)
|
||||
self.lh.setBackendValueProperty(
|
||||
self.backend, property,
|
||||
ctypes.byref(c_val), ctypes.sizeof(c_val)
|
||||
)
|
||||
else:
|
||||
self.lh.setBackendPropertyValue(self.backend, property, value)
|
||||
|
||||
def set_cta_div_n(self, div_n):
|
||||
cta_n_div_requirement = ctypes.c_int(div_n)
|
||||
self.lh.setBackendValueProperty(
|
||||
self.backend,
|
||||
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT,
|
||||
ctypes.byref(cta_n_div_requirement),
|
||||
ctypes.sizeof(cta_n_div_requirement)
|
||||
)
|
||||
self._set_backend_property(
|
||||
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT, div_n)
|
||||
|
||||
def set_cta_div_m(self, div_m):
|
||||
cta_m_div_requirement = ctypes.c_int(div_m)
|
||||
self.lh.setBackendValueProperty(
|
||||
self.backend,
|
||||
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT,
|
||||
ctypes.byref(cta_m_div_requirement),
|
||||
ctypes.sizeof(cta_m_div_requirement)
|
||||
)
|
||||
self._set_backend_property(
|
||||
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT, div_m)
|
||||
|
||||
def get_configs(self, m, n, k, batch_count, dtypes, layouts, align_a, align_b, voidC=False, use_fast_acc=True, count=1):
|
||||
if use_fast_acc:
|
||||
disable_fast_acc_for_fp8 = ctypes.c_int(0)
|
||||
else:
|
||||
disable_fast_acc_for_fp8 = ctypes.c_int(1)
|
||||
self.lh.setBackendValueProperty(
|
||||
self.backend,
|
||||
self._set_backend_property(
|
||||
self.mmh_lib.NvMatmulHeuristicsBackendProperty.DISABLE_FAST_ACC_FOR_FP8,
|
||||
ctypes.byref(disable_fast_acc_for_fp8),
|
||||
ctypes.sizeof(disable_fast_acc_for_fp8)
|
||||
0 if use_fast_acc else 1
|
||||
)
|
||||
|
||||
precision = self._precision_from_cutlass_dtypes(dtypes)
|
||||
layout = self._layout_from_cutlass(layouts)
|
||||
|
||||
matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, n, k, layout, batch_count)
|
||||
configs = self.lh.getEx(matmul_problem, count, self.backend, precision=precision)
|
||||
if self._legacy_api:
|
||||
matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, n, k, layout, batch_count)
|
||||
else:
|
||||
# >= 0.1.0.28: takes (m,n,k) as a tuple
|
||||
matmul_problem = self.lh.makeNvMatmulHeuristicsProblem((m, n, k), layout, batch_count)
|
||||
|
||||
getEx_kwargs = dict(precision=precision)
|
||||
if not self._legacy_api:
|
||||
# >= 0.1.0.28: pass hardware descriptor to getEx
|
||||
getEx_kwargs['hardware_descriptor'] = self.hw_desc
|
||||
configs = self.lh.getEx(matmul_problem, count, self.backend, **getEx_kwargs)
|
||||
|
||||
ret = []
|
||||
for c in configs:
|
||||
|
||||
Reference in New Issue
Block a user