diff --git a/kt-kernel/python/utils/amx.py b/kt-kernel/python/utils/amx.py
index 955db324..c1825a70 100644
--- a/kt-kernel/python/utils/amx.py
+++ b/kt-kernel/python/utils/amx.py
@@ -1,8 +1,11 @@
+import logging
 import os
 import torch
 import ctypes
 from typing import List, Optional
 
+logger = logging.getLogger(__name__)
+
 # Use relative imports for package structure
 from ..experts_base import BaseMoEWrapper
 from .loader import (
@@ -534,21 +537,7 @@ class NativeMoEWrapper(BaseMoEWrapper):
         )
 
         if NativeMoEWrapper._native_loader_instance is None:
-            if method == "RAWINT4":
-                NativeMoEWrapper._native_loader_instance = CompressedSafeTensorLoader(weight_path)
-            elif method == "FP8":
-                NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path)
-            elif method == "FP8_PERCHANNEL":
-                # Use FP8SafeTensorLoader with per-channel scale format
-                NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path, scale_suffix="weight_scale")
-            elif method == "BF16":
-                NativeMoEWrapper._native_loader_instance = BF16SafeTensorLoader(weight_path)
-            elif method == "GPTQ_INT4":
-                NativeMoEWrapper._native_loader_instance = GPTQSafeTensorLoader(weight_path)
-            elif method == "MXFP4":
-                NativeMoEWrapper._native_loader_instance = MXFP4SafeTensorLoader(weight_path)
-            else:
-                raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
+            NativeMoEWrapper._native_loader_instance = NativeMoEWrapper._create_loader(method, weight_path)
         self.loader = NativeMoEWrapper._native_loader_instance
 
         self.gate_weights = None
@@ -558,6 +547,46 @@ class NativeMoEWrapper(BaseMoEWrapper):
         self.up_scales = None
         self.down_scales = None
 
+    @staticmethod
+    def _create_loader(method: str, weight_path: str):
+        """Build the SafeTensor loader for ``method``; raise NotImplementedError if unsupported."""
+        if method == "RAWINT4":
+            return CompressedSafeTensorLoader(weight_path)
+        elif method == "FP8":
+            return FP8SafeTensorLoader(weight_path)
+        elif method == "FP8_PERCHANNEL":
+            # Use FP8SafeTensorLoader with per-channel scale format
+            return FP8SafeTensorLoader(weight_path, scale_suffix="weight_scale")
+        elif method == "BF16":
+            return BF16SafeTensorLoader(weight_path)
+        elif method == "GPTQ_INT4":
+            return GPTQSafeTensorLoader(weight_path)
+        elif method == "MXFP4":
+            return MXFP4SafeTensorLoader(weight_path)
+        else:
+            raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
+
+    @staticmethod
+    def _release_loader(layer_idx: int = -1):
+        """Close the shared loader's handles and drop the singleton; no-op when already released."""
+        if NativeMoEWrapper._native_loader_instance is not None:
+            NativeMoEWrapper._native_loader_instance.close_all_handles()
+            NativeMoEWrapper._native_loader_instance = None
+            if layer_idx >= 0:
+                logger.info(
+                    "[KT] Released NativeMoEWrapper loader after layer %d: "
+                    "safetensors mmap handles freed.", layer_idx,
+                )
+            else:
+                logger.info(
+                    "[KT] Released NativeMoEWrapper loader: safetensors mmap handles freed."
+                )
+
+    @staticmethod
+    def force_release_loader():
+        """Public entry point to release the shared loader at any time."""
+        NativeMoEWrapper._release_loader()
+
     def load_weights_from_tensors(
         self,
         gate_proj: torch.Tensor,
@@ -570,6 +599,20 @@ class NativeMoEWrapper(BaseMoEWrapper):
     def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
         import time
 
+        if NativeMoEWrapper._native_loader_instance is None:
+            t_recreate_start = time.time()
+            NativeMoEWrapper._native_loader_instance = NativeMoEWrapper._create_loader(
+                self.method, self.weight_path
+            )
+            self.loader = NativeMoEWrapper._native_loader_instance
+            t_recreate_elapsed = (time.time() - t_recreate_start) * 1000
+            logger.info(
+                "[KT] Recreated NativeMoEWrapper loader for layer %d (took %.1fms)",
+                self.layer_idx, t_recreate_elapsed,
+            )
+        else:
+            self.loader = NativeMoEWrapper._native_loader_instance
+
         t0 = time.time()
         base_key = f"model.layers.{self.layer_idx}"
         try:
@@ -746,6 +789,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
         del self.gate_scales
         del self.up_scales
         del self.down_scales
+
+        NativeMoEWrapper._release_loader(layer_idx=self.layer_idx)
 
         t6 = time.time()
         print(
diff --git a/kt-kernel/python/utils/loader.py b/kt-kernel/python/utils/loader.py
index ed4fd2b1..43a4b5a7 100644
--- a/kt-kernel/python/utils/loader.py
+++ b/kt-kernel/python/utils/loader.py
@@ -166,11 +166,15 @@ class SafeTensorLoader:
     def close_all_handles(self):
         """Close all file handles and clear the handle map.
 
-        Note: safetensors.safe_open doesn't have a close() method,
-        so we just clear the references and let garbage collection handle cleanup.
+        Note: safetensors.safe_open doesn't expose a close() method, so the
+        mmap is released by reference counting: once file_handle_map is
+        cleared and no tensor still references the mapped region, the handle
+        is destroyed and the OS can reclaim its pages. gc.collect() is called
+        so that handles kept alive only by reference cycles are freed now.
         """
-        # safetensors.safe_open doesn't have close(), just clear references
+        import gc
         self.file_handle_map.clear()
+        gc.collect()
 
     def load_experts(self, base_key: str, device: str = "cpu"):
         """
diff --git a/kt-kernel/test/test_native_moe_loader_auto_release.py b/kt-kernel/test/test_native_moe_loader_auto_release.py
new file mode 100644
index 00000000..6352fb91
--- /dev/null
+++ b/kt-kernel/test/test_native_moe_loader_auto_release.py
@@ -0,0 +1,203 @@
+"""Tests for NativeMoEWrapper layerwise mmap-release mechanism.
+
+Verifies that the SafeTensor loader singleton (_native_loader_instance) is
+released after EACH layer's load_weights() completes (not just after all layers),
+and that the loader is recreated on demand for the next layer.
+
+These tests exercise FakeNativeMoEWrapper, a simplified replica of the
+release/recreate logic, so they can run without actual safetensors files or
+compiled kt_kernel_ext binaries.
+"""
+
+import unittest
+
+
+class MockLoader:
+    """Minimal mock SafeTensorLoader."""
+
+    _create_count = 0  # Track how many times a loader was created
+
+    def __init__(self):
+        MockLoader._create_count += 1
+        self.closed = False
+        self.file_handle_map = {"dummy.safetensors": object()}
+
+    def close_all_handles(self):
+        self.closed = True
+        self.file_handle_map.clear()
+
+
+class FakeNativeMoEWrapper:
+    """
+    A simplified replica of NativeMoEWrapper that isolates the
+    layerwise-release + recreate logic without requiring kt_kernel_ext.
+    """
+
+    _native_loader_instance = None
+    # Number of times _ensure_loader had to create a fresh MockLoader
+    _create_loader_calls = 0
+
+    def __init__(self, layer_idx=0):
+        self.layer_idx = layer_idx
+        self.method = "FP8"
+        self.weight_path = "/fake/path"
+
+    def _ensure_loader(self):
+        """Simulate the loader-recreate logic at the start of load_weights."""
+        if FakeNativeMoEWrapper._native_loader_instance is None:
+            FakeNativeMoEWrapper._create_loader_calls += 1
+            FakeNativeMoEWrapper._native_loader_instance = MockLoader()
+        self.loader = FakeNativeMoEWrapper._native_loader_instance
+
+    def load_weights(self):
+        """Simulate load_weights: ensure loader -> do work -> release loader."""
+        self._ensure_loader()
+        # Simulate: C++ sync + del Python tensors -> release
+        FakeNativeMoEWrapper._release_loader(layer_idx=self.layer_idx)
+
+    @staticmethod
+    def _release_loader(layer_idx=-1):
+        if FakeNativeMoEWrapper._native_loader_instance is not None:
+            FakeNativeMoEWrapper._native_loader_instance.close_all_handles()
+            FakeNativeMoEWrapper._native_loader_instance = None
+
+    @staticmethod
+    def force_release_loader():
+        FakeNativeMoEWrapper._release_loader()
+
+
+def _reset_state():
+    """Reset all test state between tests."""
+    FakeNativeMoEWrapper._native_loader_instance = None
+    FakeNativeMoEWrapper._create_loader_calls = 0
+    MockLoader._create_count = 0
+
+
+# ---------------------------------------------------------------------------
+# Test cases
+# ---------------------------------------------------------------------------
+
+class TestLayerwiseRelease(unittest.TestCase):
+    """Each layer's load_weights() should release the loader afterwards."""
+
+    def setUp(self):
+        _reset_state()
+
+    def test_single_layer_released_after_load(self):
+        w = FakeNativeMoEWrapper(layer_idx=0)
+        w.load_weights()
+        self.assertIsNone(FakeNativeMoEWrapper._native_loader_instance,
+                          "Loader should be None after single layer loads")
+
+    def test_each_layer_releases_loader(self):
+        """After every layer's load_weights(), the loader should be None."""
+        for i in range(5):
+            w = FakeNativeMoEWrapper(layer_idx=i)
+            w.load_weights()
+            self.assertIsNone(
+                FakeNativeMoEWrapper._native_loader_instance,
+                f"Loader should be None after layer {i} loads",
+            )
+
+    def test_loader_recreated_for_each_layer(self):
+        """Each layer should trigger a loader recreation (since previous layer released it)."""
+        N = 4
+        for i in range(N):
+            w = FakeNativeMoEWrapper(layer_idx=i)
+            w.load_weights()
+
+        # The loader starts as None and every layer releases it after loading,
+        # so each of the N layers has to create its own loader.
+        self.assertEqual(
+            FakeNativeMoEWrapper._create_loader_calls, N,
+            f"Expected {N} loader recreations for {N} layers, "
+            f"got {FakeNativeMoEWrapper._create_loader_calls}",
+        )
+        self.assertEqual(MockLoader._create_count, N)
+
+
+class TestLoaderRecreate(unittest.TestCase):
+    """Loader should be recreated when _native_loader_instance is None."""
+
+    def setUp(self):
+        _reset_state()
+
+    def test_first_layer_creates_loader(self):
+        w = FakeNativeMoEWrapper(layer_idx=0)
+        w.load_weights()
+        # Loader was created (then released), but the creation happened
+        self.assertGreater(FakeNativeMoEWrapper._create_loader_calls, 0)
+
+    def test_second_layer_recreates_after_first_released(self):
+        w0 = FakeNativeMoEWrapper(layer_idx=0)
+        w0.load_weights()
+        self.assertIsNone(FakeNativeMoEWrapper._native_loader_instance)
+
+        w1 = FakeNativeMoEWrapper(layer_idx=1)
+        w1.load_weights()
+        # Second layer should have recreated the loader
+        self.assertEqual(FakeNativeMoEWrapper._create_loader_calls, 2,
+                         "Both layers should recreate the loader")
+
+    def test_pre_existing_loader_not_recreated(self):
+        """If loader already exists (e.g., from __init__), it should not be recreated."""
+        loader = MockLoader()
+        FakeNativeMoEWrapper._native_loader_instance = loader
+        initial_create_calls = FakeNativeMoEWrapper._create_loader_calls
+
+        w = FakeNativeMoEWrapper(layer_idx=0)
+        w._ensure_loader()
+        # No new creation should happen
+        self.assertEqual(FakeNativeMoEWrapper._create_loader_calls, initial_create_calls)
+        self.assertIs(w.loader, loader)
+
+
+class TestForceReleaseLoader(unittest.TestCase):
+    """force_release_loader() should work at any time."""
+
+    def setUp(self):
+        _reset_state()
+
+    def test_force_release_before_any_load(self):
+        loader = MockLoader()
+        FakeNativeMoEWrapper._native_loader_instance = loader
+
+        FakeNativeMoEWrapper.force_release_loader()
+
+        self.assertIsNone(FakeNativeMoEWrapper._native_loader_instance)
+        self.assertTrue(loader.closed)
+
+    def test_force_release_when_loader_is_none(self):
+        """force_release_loader() should be safe even if loader is already None."""
+        FakeNativeMoEWrapper._native_loader_instance = None
+        FakeNativeMoEWrapper.force_release_loader()
+        self.assertIsNone(FakeNativeMoEWrapper._native_loader_instance)
+
+    def test_force_release_mid_loading(self):
+        loader = MockLoader()
+        FakeNativeMoEWrapper._native_loader_instance = loader
+
+        # Load first layer
+        w0 = FakeNativeMoEWrapper(layer_idx=0)
+        w0.load_weights()
+        # Loader is now released (each layer releases it)
+        self.assertIsNone(FakeNativeMoEWrapper._native_loader_instance)
+        self.assertTrue(loader.closed)
+
+        # Set a new loader manually
+        loader2 = MockLoader()
+        FakeNativeMoEWrapper._native_loader_instance = loader2
+
+        # Force release before next load
+        FakeNativeMoEWrapper.force_release_loader()
+
+        self.assertIsNone(FakeNativeMoEWrapper._native_loader_instance)
+        self.assertTrue(loader2.closed)
+
+
+# ---------------------------------------------------------------------------
+# Run directly
+# ---------------------------------------------------------------------------
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)