mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-05-20 04:19:17 +00:00
[fix](kt-kernel): fix double mem used by safetensor loader (#1997)
Release the SafeTensor mmap loader singleton after each layer's load_weights() completes. The C++ engine already holds a deep copy (cpu_infer.sync() guarantees this), so releasing the mmap handles is safe. The next layer recreates the loader on demand. This halves peak memory usage during model loading (e.g. DSv3.2: 1.2T -> 613G). Based on #1966 by @poryfly — adapted to v0.6.2.post3 codebase (adds MXFP4 support missing from the original PR). Co-authored-by: xiongchenhui <xiongchenhui@hisense.com>
This commit is contained in:
@@ -1,8 +1,12 @@
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import torch
|
||||
import ctypes
|
||||
from typing import List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Use relative imports for package structure
|
||||
from ..experts_base import BaseMoEWrapper
|
||||
from .loader import (
|
||||
@@ -534,21 +538,7 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
)
|
||||
|
||||
if NativeMoEWrapper._native_loader_instance is None:
|
||||
if method == "RAWINT4":
|
||||
NativeMoEWrapper._native_loader_instance = CompressedSafeTensorLoader(weight_path)
|
||||
elif method == "FP8":
|
||||
NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path)
|
||||
elif method == "FP8_PERCHANNEL":
|
||||
# Use FP8SafeTensorLoader with per-channel scale format
|
||||
NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path, scale_suffix="weight_scale")
|
||||
elif method == "BF16":
|
||||
NativeMoEWrapper._native_loader_instance = BF16SafeTensorLoader(weight_path)
|
||||
elif method == "GPTQ_INT4":
|
||||
NativeMoEWrapper._native_loader_instance = GPTQSafeTensorLoader(weight_path)
|
||||
elif method == "MXFP4":
|
||||
NativeMoEWrapper._native_loader_instance = MXFP4SafeTensorLoader(weight_path)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
|
||||
NativeMoEWrapper._native_loader_instance = NativeMoEWrapper._create_loader(method, weight_path)
|
||||
self.loader = NativeMoEWrapper._native_loader_instance
|
||||
|
||||
self.gate_weights = None
|
||||
@@ -558,6 +548,42 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
self.up_scales = None
|
||||
self.down_scales = None
|
||||
|
||||
@staticmethod
def _create_loader(method: str, weight_path: str):
    """Build the SafeTensor loader that matches a quantization method.

    Args:
        method: One of "RAWINT4", "FP8", "FP8_PERCHANNEL", "BF16",
            "GPTQ_INT4" or "MXFP4".
        weight_path: Forwarded unchanged to the loader constructor.

    Returns:
        A newly constructed loader instance for ``method``.

    Raises:
        NotImplementedError: If ``method`` is not one of the names above.
    """
    # Each entry is a lazy factory, so a loader class is only looked up and
    # constructed for the method that was actually requested.
    factories = (
        ("RAWINT4", lambda: CompressedSafeTensorLoader(weight_path)),
        ("FP8", lambda: FP8SafeTensorLoader(weight_path)),
        # Per-channel FP8 reuses the FP8 loader with a different scale suffix.
        (
            "FP8_PERCHANNEL",
            lambda: FP8SafeTensorLoader(weight_path, scale_suffix="weight_scale"),
        ),
        ("BF16", lambda: BF16SafeTensorLoader(weight_path)),
        ("GPTQ_INT4", lambda: GPTQSafeTensorLoader(weight_path)),
        ("MXFP4", lambda: MXFP4SafeTensorLoader(weight_path)),
    )
    for known_method, build in factories:
        if method == known_method:
            return build()
    raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
|
||||
|
||||
@staticmethod
def _release_loader(layer_idx: int = -1):
    """Close and drop the shared loader held on ``NativeMoEWrapper``.

    Does nothing when no loader is currently cached. Otherwise calls
    ``close_all_handles()`` on the cached loader, resets
    ``NativeMoEWrapper._native_loader_instance`` to ``None`` and logs the
    release at INFO level.

    Args:
        layer_idx: Layer index to mention in the log message. A negative
            value (the default) selects the generic message without a
            layer number.
    """
    if NativeMoEWrapper._native_loader_instance is None:
        return

    NativeMoEWrapper._native_loader_instance.close_all_handles()
    # Clearing the class attribute makes the next user build a new loader.
    NativeMoEWrapper._native_loader_instance = None

    if layer_idx >= 0:
        logger.info(
            "[KT] Released NativeMoEWrapper loader after layer %d: "
            "safetensors mmap handles freed.",
            layer_idx,
        )
        return

    logger.info("[KT] Released NativeMoEWrapper loader: safetensors mmap handles freed.")
|
||||
|
||||
@staticmethod
def force_release_loader():
    """Release the shared loader without tying the release to a layer.

    Delegates to ``NativeMoEWrapper._release_loader`` with a negative
    ``layer_idx``, so the generic (layer-less) log message is used.
    """
    NativeMoEWrapper._release_loader(layer_idx=-1)
|
||||
|
||||
def load_weights_from_tensors(
|
||||
self,
|
||||
gate_proj: torch.Tensor,
|
||||
@@ -570,6 +596,20 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
|
||||
import time
|
||||
|
||||
if NativeMoEWrapper._native_loader_instance is None:
|
||||
t_recreate_start = time.time()
|
||||
NativeMoEWrapper._native_loader_instance = NativeMoEWrapper._create_loader(
|
||||
self.method, self.weight_path
|
||||
)
|
||||
self.loader = NativeMoEWrapper._native_loader_instance
|
||||
t_recreate_elapsed = (time.time() - t_recreate_start) * 1000
|
||||
logger.info(
|
||||
"[KT] Recreated NativeMoEWrapper loader for layer %d (took %.1fms)",
|
||||
self.layer_idx, t_recreate_elapsed,
|
||||
)
|
||||
else:
|
||||
self.loader = NativeMoEWrapper._native_loader_instance
|
||||
|
||||
t0 = time.time()
|
||||
base_key = f"model.layers.{self.layer_idx}"
|
||||
try:
|
||||
@@ -746,6 +786,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
del self.gate_scales
|
||||
del self.up_scales
|
||||
del self.down_scales
|
||||
|
||||
NativeMoEWrapper._release_loader(layer_idx=self.layer_idx)
|
||||
t6 = time.time()
|
||||
|
||||
print(
|
||||
|
||||
@@ -166,11 +166,15 @@ class SafeTensorLoader:
|
||||
def close_all_handles(self):
    """Drop every cached file handle and run a garbage-collection pass.

    The handles are not closed explicitly; this method only removes the
    references stored in ``self.file_handle_map``. ``gc.collect()`` is then
    invoked so that objects which became unreachable are collected right
    away instead of at some later collection.

    NOTE(review): the underlying mapping can only be released once nothing
    else (for example a tensor created from a handle) still references it;
    confirm callers have dropped such tensors before relying on the memory
    being returned.
    """
    import gc

    # Empty the map in place so any alias of the dict sees it emptied too.
    self.file_handle_map.clear()

    gc.collect()
|
||||
|
||||
def load_experts(self, base_key: str, device: str = "cpu"):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user