v4.4 tag release update. (#3032)

This commit is contained in:
Junkai-Wu
2026-02-14 12:27:58 +08:00
committed by GitHub
parent 01687cfba1
commit d4bbf728ca
140 changed files with 41624 additions and 3691 deletions

View File

@@ -54,6 +54,7 @@ class MatmulHeuristics:
def __init__(self, gpu = None):
import nvMatmulHeuristics
import inspect
self.mmh_lib = nvMatmulHeuristics
self.gpu = gpu
@@ -62,13 +63,63 @@ class MatmulHeuristics:
else:
nvmmhInterfaceEx = self.mmh_lib.NvMatmulHeuristicsInterfaceEx
self.lh = nvmmhInterfaceEx(
# nvidia-matmul-heuristics 0.1.0.28 changed the API:
# - Constructor: removed 'load_discovery_implicitly' and 'gpu' params
# - GPU: now set via createHardwareDescriptor() + setHardwarePredefinedGpu()
# - setBackendValueProperty renamed to setBackendPropertyValue (simpler signature)
# - getEx: added hardware_descriptor parameter
init_params = set(inspect.signature(self.mmh_lib.NvMatmulHeuristicsInterfaceEx.__init__).parameters.keys())
self._legacy_api = 'load_discovery_implicitly' in init_params
init_kwargs = dict(
backend=self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"],
flags=self.mmh_lib.NvMatmulHeuristicsFlags.PERF_MODEL_BASED_AUTO_TUNING,
load_discovery_implicitly=True,
gpu=self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None
)
if self._legacy_api:
# <= 0.1.0.27
init_kwargs['gpu'] = self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None
init_kwargs['load_discovery_implicitly'] = True
self.lh = nvmmhInterfaceEx(**init_kwargs)
# >= 0.1.0.28: gpu is set via hardware descriptor after construction,
# and passed to getEx() calls
self.hw_desc = None
if not self._legacy_api and self.gpu:
self.hw_desc = self.lh.createHardwareDescriptor()
if self.hw_desc is None:
raise RuntimeError("Failed to create hardware descriptor for GPU: " + self.gpu)
self.lh.setHardwarePredefinedGpu(self.hw_desc, self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu])
self.backend = self.lh.createBackend(self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"])
if not self._legacy_api:
lh = self.lh
original_del = type(lh).__del__
def _safe_del(self_lh):
try:
original_del(self_lh)
except Exception:
pass
type(lh).__del__ = _safe_del
def __del__(self):
"""Clean up resources in correct order before the library's __del__ runs."""
try:
if hasattr(self, 'backend') and self.backend:
self.lh.destroyBackend(self.backend)
self.backend = None
if hasattr(self, 'hw_desc') and self.hw_desc:
self.lh.destroyHardwareDescriptor(self.hw_desc)
self.hw_desc = None
# Null out the handle so the library's __del__ skips nvMatmulHeuristicsDestroy
if hasattr(self, 'lh') and self.lh and hasattr(self.lh, 'handle'):
self.lh.handle = None
except Exception:
pass
def _layout_from_cutlass(self, layouts):
assert(len(layouts)==3)
@@ -98,41 +149,45 @@ class MatmulHeuristics:
else:
return a_c + dtype_to_cublas[dtype_b] + dtype_to_cublas[dtype_c] + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d]
def _set_backend_property(self, property, value):
"""Compat wrapper: setBackendValueProperty (<=0.1.0.27) vs setBackendPropertyValue (>=0.1.0.28)"""
if self._legacy_api:
c_val = ctypes.c_int(value)
self.lh.setBackendValueProperty(
self.backend, property,
ctypes.byref(c_val), ctypes.sizeof(c_val)
)
else:
self.lh.setBackendPropertyValue(self.backend, property, value)
def set_cta_div_n(self, div_n):
cta_n_div_requirement = ctypes.c_int(div_n)
self.lh.setBackendValueProperty(
self.backend,
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT,
ctypes.byref(cta_n_div_requirement),
ctypes.sizeof(cta_n_div_requirement)
)
self._set_backend_property(
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT, div_n)
def set_cta_div_m(self, div_m):
cta_m_div_requirement = ctypes.c_int(div_m)
self.lh.setBackendValueProperty(
self.backend,
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT,
ctypes.byref(cta_m_div_requirement),
ctypes.sizeof(cta_m_div_requirement)
)
self._set_backend_property(
self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT, div_m)
def get_configs(self, m, n, k, batch_count, dtypes, layouts, align_a, align_b, voidC=False, use_fast_acc=True, count=1):
if use_fast_acc:
disable_fast_acc_for_fp8 = ctypes.c_int(0)
else:
disable_fast_acc_for_fp8 = ctypes.c_int(1)
self.lh.setBackendValueProperty(
self.backend,
self._set_backend_property(
self.mmh_lib.NvMatmulHeuristicsBackendProperty.DISABLE_FAST_ACC_FOR_FP8,
ctypes.byref(disable_fast_acc_for_fp8),
ctypes.sizeof(disable_fast_acc_for_fp8)
0 if use_fast_acc else 1
)
precision = self._precision_from_cutlass_dtypes(dtypes)
layout = self._layout_from_cutlass(layouts)
matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, n, k, layout, batch_count)
configs = self.lh.getEx(matmul_problem, count, self.backend, precision=precision)
if self._legacy_api:
matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, n, k, layout, batch_count)
else:
# >= 0.1.0.28: takes (m,n,k) as a tuple
matmul_problem = self.lh.makeNvMatmulHeuristicsProblem((m, n, k), layout, batch_count)
getEx_kwargs = dict(precision=precision)
if not self._legacy_api:
# >= 0.1.0.28: pass hardware descriptor to getEx
getEx_kwargs['hardware_descriptor'] = self.hw_desc
configs = self.lh.getEx(matmul_problem, count, self.backend, **getEx_kwargs)
ret = []
for c in configs: