[fix](kt-kernel): fix double mem used by safetensor loader (#1997)

Release the SafeTensor mmap loader singleton after each layer's
load_weights() completes. The C++ engine already holds a deep copy
(cpu_infer.sync() guarantees this), so releasing the mmap handles is
safe. The next layer recreates the loader on demand.

This halves peak memory usage during model loading (e.g. DSv3.2:
1.2T -> 613G).

Based on #1966 by @poryfly — adapted to v0.6.2.post3 codebase
(adds MXFP4 support missing from the original PR).

Co-authored-by: xiongchenhui <xiongchenhui@hisense.com>
This commit is contained in:
Benjamin F
2026-05-11 12:00:30 +08:00
committed by GitHub
parent bb15fdf47e
commit f05b4009f3
3 changed files with 268 additions and 18 deletions

View File

@@ -1,8 +1,12 @@
import gc
import logging
import os
import torch
import ctypes
from typing import List, Optional
logger = logging.getLogger(__name__)
# Use relative imports for package structure
from ..experts_base import BaseMoEWrapper
from .loader import (
@@ -534,21 +538,7 @@ class NativeMoEWrapper(BaseMoEWrapper):
)
if NativeMoEWrapper._native_loader_instance is None:
if method == "RAWINT4":
NativeMoEWrapper._native_loader_instance = CompressedSafeTensorLoader(weight_path)
elif method == "FP8":
NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path)
elif method == "FP8_PERCHANNEL":
# Use FP8SafeTensorLoader with per-channel scale format
NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path, scale_suffix="weight_scale")
elif method == "BF16":
NativeMoEWrapper._native_loader_instance = BF16SafeTensorLoader(weight_path)
elif method == "GPTQ_INT4":
NativeMoEWrapper._native_loader_instance = GPTQSafeTensorLoader(weight_path)
elif method == "MXFP4":
NativeMoEWrapper._native_loader_instance = MXFP4SafeTensorLoader(weight_path)
else:
raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
NativeMoEWrapper._native_loader_instance = NativeMoEWrapper._create_loader(method, weight_path)
self.loader = NativeMoEWrapper._native_loader_instance
self.gate_weights = None
@@ -558,6 +548,42 @@ class NativeMoEWrapper(BaseMoEWrapper):
self.up_scales = None
self.down_scales = None
@staticmethod
def _create_loader(method: str, weight_path: str):
if method == "RAWINT4":
return CompressedSafeTensorLoader(weight_path)
elif method == "FP8":
return FP8SafeTensorLoader(weight_path)
elif method == "FP8_PERCHANNEL":
return FP8SafeTensorLoader(weight_path, scale_suffix="weight_scale")
elif method == "BF16":
return BF16SafeTensorLoader(weight_path)
elif method == "GPTQ_INT4":
return GPTQSafeTensorLoader(weight_path)
elif method == "MXFP4":
return MXFP4SafeTensorLoader(weight_path)
else:
raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
@staticmethod
def _release_loader(layer_idx: int = -1):
if NativeMoEWrapper._native_loader_instance is not None:
NativeMoEWrapper._native_loader_instance.close_all_handles()
NativeMoEWrapper._native_loader_instance = None
if layer_idx >= 0:
logger.info(
"[KT] Released NativeMoEWrapper loader after layer %d: "
"safetensors mmap handles freed.", layer_idx,
)
else:
logger.info(
"[KT] Released NativeMoEWrapper loader: safetensors mmap handles freed."
)
@staticmethod
def force_release_loader():
NativeMoEWrapper._release_loader()
def load_weights_from_tensors(
self,
gate_proj: torch.Tensor,
@@ -570,6 +596,20 @@ class NativeMoEWrapper(BaseMoEWrapper):
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
import time
if NativeMoEWrapper._native_loader_instance is None:
t_recreate_start = time.time()
NativeMoEWrapper._native_loader_instance = NativeMoEWrapper._create_loader(
self.method, self.weight_path
)
self.loader = NativeMoEWrapper._native_loader_instance
t_recreate_elapsed = (time.time() - t_recreate_start) * 1000
logger.info(
"[KT] Recreated NativeMoEWrapper loader for layer %d (took %.1fms)",
self.layer_idx, t_recreate_elapsed,
)
else:
self.loader = NativeMoEWrapper._native_loader_instance
t0 = time.time()
base_key = f"model.layers.{self.layer_idx}"
try:
@@ -746,6 +786,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
del self.gate_scales
del self.up_scales
del self.down_scales
NativeMoEWrapper._release_loader(layer_idx=self.layer_idx)
t6 = time.time()
print(

View File

@@ -166,11 +166,15 @@ class SafeTensorLoader:
def close_all_handles(self):
"""Close all file handles and clear the handle map.
Note: safetensors.safe_open doesn't have a close() method,
so we just clear the references and let garbage collection handle cleanup.
Note: safetensors.safe_open doesn't expose a close() method. Releasing
the mmap relies on reference counting: once file_handle_map is cleared
and no tensor holds a reference to the underlying mmap region, the OS
will reclaim the page cache. gc.collect() is called here to trigger
immediate reclamation rather than waiting for the next GC cycle.
"""
# safetensors.safe_open doesn't have close(), just clear references
import gc
self.file_handle_map.clear()
gc.collect()
def load_experts(self, base_key: str, device: str = "cpu"):
"""