From cc63c99f112f781605017db018fced686cdc94de Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Sun, 22 Feb 2026 16:13:38 +0800 Subject: [PATCH] Enhance hook mechanism in dumper (#19073) --- python/sglang/srt/debug_utils/dumper.py | 91 ++++++++++++++----- .../debug_utils/test_dump_comparator.py | 2 +- test/registered/debug_utils/test_dumper.py | 47 +++++++++- 3 files changed, 109 insertions(+), 31 deletions(-) diff --git a/python/sglang/srt/debug_utils/dumper.py b/python/sglang/srt/debug_utils/dumper.py index d4cafb1fc..bdfd87bc5 100644 --- a/python/sglang/srt/debug_utils/dumper.py +++ b/python/sglang/srt/debug_utils/dumper.py @@ -1,3 +1,4 @@ +import functools import json import os import re @@ -90,14 +91,14 @@ class _FrozenConfig(ABC): class _DumperConfig(_FrozenConfig): enable: bool = False filter: Optional[str] = None - dir: str = "/tmp" + dir: str = "/tmp/dumper" enable_output_file: bool = True enable_output_console: bool = True enable_value: bool = True enable_grad: bool = False enable_model_value: bool = True enable_model_grad: bool = True - partial_name: Optional[str] = None + exp_name: Optional[str] = None enable_http_server: bool = True cleanup_previous: bool = False collective_timeout: int = 60 @@ -176,7 +177,7 @@ class _Dumper: return # Users may want to `dump` only on some ranks, thus determine name here - self._ensure_partial_name() + self._ensure_exp_name() self._step += 1 print(f"[Dumper] [{time.time()}] step={self._step}") @@ -388,7 +389,7 @@ class _Dumper: save: bool, step: Optional[int] = None, ) -> None: - self._ensure_partial_name() + self._ensure_exp_name() self._dump_index += 1 rank = _get_rank() @@ -399,11 +400,7 @@ class _Dumper: **tags, ) full_filename = _format_tags(full_kwargs) + ".pt" - path = ( - Path(self._config.dir) - / f"sglang_dump_{self._config.partial_name}" - / full_filename - ) + path = Path(self._config.dir) / self._config.exp_name / full_filename if self._config.enable_output_console: print( @@ -468,11 +465,13 @@ class _Dumper: _start_http_server(prefix="/dumper/", target=self, http_port=http_port) print(f"[Dumper] HTTP server started on port {http_port}") - def _ensure_partial_name(self): - if self._config.partial_name is None: - name = _get_partial_name(timeout_seconds=self._config.collective_timeout) - self.configure(partial_name=name) - print(f"[Dumper] Choose partial_name={name}") + def _ensure_exp_name(self): + if self._config.exp_name is None: + name = _get_default_exp_name( + timeout_seconds=self._config.collective_timeout + ) + self.configure(exp_name=name) + print(f"[Dumper] Choose exp_name={name}") # -------------------------------------- hook dumper ------------------------------------------ @@ -496,11 +495,16 @@ class _NonIntrusiveDumper: if ctx := self._detect_module_ctx(module_name, module): self._register_ctx_hooks(module, ctx=ctx) - module.register_forward_hook( - self._make_forward_hook( - module_name=module_name, - is_root=(module_name == ""), - ) + is_root = module_name == "" + pre_hook = self._make_forward_pre_hook( + module_name=module_name, is_root=is_root + ) + hook = self._make_forward_hook(module_name=module_name, is_root=is_root) + _register_forward_hook_or_replace_fn( + module, + pre_hook=pre_hook, + hook=hook, + mode="replace_fn" if is_root else "hook", ) @classmethod @@ -525,11 +529,15 @@ class _NonIntrusiveDumper: ) ) - def _make_forward_hook(self, *, module_name: str, is_root: bool): - def _hook(_module, input, output): + def _make_forward_pre_hook(self, *, module_name: str, is_root: bool): + def _hook(_module, input): for i, item in enumerate(input): self._dump_value(module_name, item, role=f"inputs.{i}", is_root=is_root) + return _hook + + def _make_forward_hook(self, *, module_name: str, is_root: bool): + def _hook(_module, input, output): if output is not None: self._dump_value(module_name, output, role="output", is_root=False) @@ -566,6 +574,39 @@ class _NonIntrusiveDumper: return {} +def _register_forward_hook_or_replace_fn( + module: "torch.nn.Module", + *, + pre_hook, + hook, + mode: str, +) -> None: + """Attach pre/post forward hooks to *module*. + + mode="hook" — standard ``register_forward_pre_hook`` / ``register_forward_hook`` + (fires only via ``__call__``). + mode="replace_fn" — monkey-patch ``module.forward`` so hooks fire even when + callers invoke ``.forward()`` directly (as sglang does for the + root model). + """ + if mode == "hook": + module.register_forward_pre_hook(pre_hook) + module.register_forward_hook(hook) + elif mode == "replace_fn": + original_forward = module.forward + + @functools.wraps(original_forward) + def _wrapped(*args, **kwargs): + pre_hook(module, args) + output = original_forward(*args, **kwargs) + hook(module, args, output) + return output + + module.forward = _wrapped + else: + raise ValueError(f"Unknown mode {mode!r}") + + # -------------------------------------- util fn ------------------------------------------ @@ -604,14 +645,14 @@ def _collective_with_timeout(fn, operation_name: str, timeout_seconds: int = 60) completed.set() -def _get_partial_name(timeout_seconds: int = 60): +def _get_default_exp_name(timeout_seconds: int = 60): rank = _get_rank() - object_list = [str(time.time()) if rank == 0 else None] + object_list = [f"dump_{time.time()}" if rank == 0 else None] if dist.is_initialized(): _collective_with_timeout( lambda: dist.broadcast_object_list(object_list, device="cuda"), - operation_name="broadcast_object_list in _get_partial_name", + operation_name="broadcast_object_list in _get_default_exp_name", timeout_seconds=timeout_seconds, ) @@ -622,7 +663,7 @@ def _cleanup_old_dumps(base_dir: Path) -> None: import shutil if _get_rank() == 0: - for entry in base_dir.glob("sglang_dump_*"): + for entry in base_dir.glob("dump_*"): if entry.is_dir(): shutil.rmtree(entry) print(f"[Dumper] Cleaned up {entry}") diff --git a/test/registered/debug_utils/test_dump_comparator.py b/test/registered/debug_utils/test_dump_comparator.py index 9fe95e58d..7b0fffd81 100644 --- a/test/registered/debug_utils/test_dump_comparator.py +++ b/test/registered/debug_utils/test_dump_comparator.py @@ -108,7 +108,7 @@ class TestEndToEnd(CustomTestCase): dumper.step() dumper.dump("tensor_b", tensor * 2) dumper.step() - dump_dirs.append(Path(d) / f"sglang_dump_{dumper._config.partial_name}") + dump_dirs.append(Path(d) / dumper._config.exp_name) args = Namespace( baseline_path=str(dump_dirs[0]), diff --git a/test/registered/debug_utils/test_dumper.py b/test/registered/debug_utils/test_dumper.py index b96ece72a..1e690fe76 100644 --- a/test/registered/debug_utils/test_dumper.py +++ b/test/registered/debug_utils/test_dumper.py @@ -440,7 +440,7 @@ def _make_test_dumper(tmp_path, **overrides) -> _Dumper: config = _DumperConfig( enable=True, dir=str(tmp_path), - partial_name="test", + exp_name="test", enable_http_server=False, **overrides, ) @@ -448,7 +448,7 @@ def _make_test_dumper(tmp_path, **overrides) -> _Dumper: def _get_filenames(tmpdir): - return {f.name for f in Path(tmpdir).glob("sglang_dump_*/*.pt")} + return {f.name for f in Path(tmpdir).glob("*/*.pt")} def _assert_files(filenames, *, exist=(), not_exist=()): @@ -468,7 +468,7 @@ def _load_dump(path: Path) -> dict: def _find_dump_file(tmpdir, *, rank: int = 0, name: str) -> Path: matches = [ f - for f in Path(tmpdir).glob("sglang_dump_*/*.pt") + for f in Path(tmpdir).glob("*/*.pt") if f"rank={rank}" in f.name and name in f.name ] assert ( @@ -770,7 +770,7 @@ class TestDumpModel: class TestCleanup: def test_cleanup_removes_old_dumps(self, tmp_path): - old_dir = tmp_path / "sglang_dump_old" + old_dir = tmp_path / "dump_old" old_dir.mkdir() (old_dir / "dummy.pt").touch() @@ -781,7 +781,7 @@ class TestCleanup: _assert_files(_get_filenames(tmp_path), exist=["new_tensor"]) def test_no_cleanup_by_default(self, tmp_path): - old_dir = tmp_path / "sglang_dump_old" + old_dir = tmp_path / "dump_old" old_dir.mkdir() (old_dir / "dummy.pt").touch() @@ -1020,6 +1020,43 @@ class TestNonIntrusiveDumper(_NonIntrusiveTestBase): P = self._PREFIX assert torch.allclose(captured[f"{P}output"]["value"], output) + def test_inputs_dumped_before_forward(self, tmp_path): + """Inputs are captured *before* forward(); in-place mutation must not affect them.""" + + class Mutator(torch.nn.Module): + def forward(self, x): + x.fill_(999.0) + return x + + class Inner(torch.nn.Module): + def __init__(self): + super().__init__() + self.mutator = Mutator() + + def forward(self, x): + return self.mutator(x) + + d = self._make_dumper(tmp_path) + model = self._wrap_as_outer(Inner) + d.register_non_intrusive_dumper(model) + + x = torch.randn(2, 4) + original_x = x.clone() + with d.capture_output() as captured: + model(x) + + P = self._PREFIX + dumped_input = captured[f"{P}model.mutator.inputs.0"]["value"] + assert torch.allclose(dumped_input, original_x), ( + f"pre-hook should capture inputs before forward mutates them; " + f"got {dumped_input} but expected {original_x}" + ) + + dumped_output = captured[f"{P}model.mutator.output"]["value"] + assert ( + dumped_output == 999.0 + ).all(), "post-hook should capture outputs after forward" + def test_hooks_all_module_levels(self, tmp_path): class Attention(torch.nn.Module): def __init__(self):