Enhance hook mechanism in dumper (#19073)

This commit is contained in:
fzyzcjy
2026-02-22 16:13:38 +08:00
committed by GitHub
parent fdf80b5031
commit cc63c99f11
3 changed files with 109 additions and 31 deletions

View File

@@ -1,3 +1,4 @@
import functools
import json
import os
import re
@@ -90,14 +91,14 @@ class _FrozenConfig(ABC):
class _DumperConfig(_FrozenConfig):
enable: bool = False
filter: Optional[str] = None
dir: str = "/tmp"
dir: str = "/tmp/dumper"
enable_output_file: bool = True
enable_output_console: bool = True
enable_value: bool = True
enable_grad: bool = False
enable_model_value: bool = True
enable_model_grad: bool = True
partial_name: Optional[str] = None
exp_name: Optional[str] = None
enable_http_server: bool = True
cleanup_previous: bool = False
collective_timeout: int = 60
@@ -176,7 +177,7 @@ class _Dumper:
return
# Users may want to `dump` only on some ranks, thus determine name here
self._ensure_partial_name()
self._ensure_exp_name()
self._step += 1
print(f"[Dumper] [{time.time()}] step={self._step}")
@@ -388,7 +389,7 @@ class _Dumper:
save: bool,
step: Optional[int] = None,
) -> None:
self._ensure_partial_name()
self._ensure_exp_name()
self._dump_index += 1
rank = _get_rank()
@@ -399,11 +400,7 @@ class _Dumper:
**tags,
)
full_filename = _format_tags(full_kwargs) + ".pt"
path = (
Path(self._config.dir)
/ f"sglang_dump_{self._config.partial_name}"
/ full_filename
)
path = Path(self._config.dir) / self._config.exp_name / full_filename
if self._config.enable_output_console:
print(
@@ -468,11 +465,13 @@ class _Dumper:
_start_http_server(prefix="/dumper/", target=self, http_port=http_port)
print(f"[Dumper] HTTP server started on port {http_port}")
def _ensure_partial_name(self):
    """Choose and store ``partial_name`` when the config does not have one yet.

    Does nothing if ``self._config.partial_name`` is already set. Otherwise the
    name is obtained from ``_get_partial_name`` (bounded by the configured
    ``collective_timeout``), written back via ``self.configure`` and logged.
    """
    if self._config.partial_name is not None:
        return

    chosen = _get_partial_name(timeout_seconds=self._config.collective_timeout)
    self.configure(partial_name=chosen)
    print(f"[Dumper] Choose partial_name={chosen}")
def _ensure_exp_name(self):
    """Choose and store ``exp_name`` when the config does not have one yet.

    Does nothing if ``self._config.exp_name`` is already set. Otherwise the
    name is obtained from ``_get_default_exp_name`` (bounded by the configured
    ``collective_timeout``), written back via ``self.configure`` and logged.
    """
    if self._config.exp_name is not None:
        return

    timeout = self._config.collective_timeout
    chosen = _get_default_exp_name(timeout_seconds=timeout)
    self.configure(exp_name=chosen)
    print(f"[Dumper] Choose exp_name={chosen}")
# -------------------------------------- hook dumper ------------------------------------------
@@ -496,11 +495,16 @@ class _NonIntrusiveDumper:
if ctx := self._detect_module_ctx(module_name, module):
self._register_ctx_hooks(module, ctx=ctx)
module.register_forward_hook(
self._make_forward_hook(
module_name=module_name,
is_root=(module_name == ""),
)
is_root = module_name == ""
pre_hook = self._make_forward_pre_hook(
module_name=module_name, is_root=is_root
)
hook = self._make_forward_hook(module_name=module_name, is_root=is_root)
_register_forward_hook_or_replace_fn(
module,
pre_hook=pre_hook,
hook=hook,
mode="replace_fn" if is_root else "hook",
)
@classmethod
@@ -525,11 +529,15 @@ class _NonIntrusiveDumper:
)
)
def _make_forward_hook(self, *, module_name: str, is_root: bool):
def _hook(_module, input, output):
def _make_forward_pre_hook(self, *, module_name: str, is_root: bool):
    """Build a forward pre-hook that dumps every positional input of a module.

    The returned callable takes ``(module, input)`` and, for each element of
    ``input``, calls ``self._dump_value`` with role ``inputs.<index>`` and the
    ``module_name`` / ``is_root`` captured here.
    """

    def _hook(_module, input):
        position = 0
        for value in input:
            self._dump_value(
                module_name,
                value,
                role=f"inputs.{position}",
                is_root=is_root,
            )
            position += 1

    return _hook
def _make_forward_hook(self, *, module_name: str, is_root: bool):
def _hook(_module, input, output):
if output is not None:
self._dump_value(module_name, output, role="output", is_root=False)
@@ -566,6 +574,39 @@ class _NonIntrusiveDumper:
return {}
def _register_forward_hook_or_replace_fn(
    module: "torch.nn.Module",
    *,
    pre_hook,
    hook,
    mode: str,
) -> None:
    """Attach pre/post forward hooks to *module*.

    mode="hook"       — standard ``register_forward_pre_hook`` / ``register_forward_hook``
                        (fires only via ``__call__``).
    mode="replace_fn" — monkey-patch ``module.forward`` so hooks fire even when
                        callers invoke ``.forward()`` directly (as sglang does for the
                        root model).

    In ``replace_fn`` mode the hooks are invoked as ``pre_hook(module, args)`` and
    ``hook(module, args, output)``; keyword arguments are forwarded to the original
    ``forward`` but are not shown to the hooks. To stay consistent with the torch
    hook contract used by ``mode="hook"``, a non-``None`` return value from
    ``pre_hook`` replaces the positional inputs (a non-tuple value is wrapped into
    a 1-tuple) and a non-``None`` return value from ``hook`` replaces the output.
    Hooks that return ``None`` leave inputs and output untouched.

    Raises:
        ValueError: if *mode* is neither ``"hook"`` nor ``"replace_fn"``.
    """
    if mode == "hook":
        module.register_forward_pre_hook(pre_hook)
        module.register_forward_hook(hook)
    elif mode == "replace_fn":
        # Capture the current forward once; the wrapper below shadows it on the
        # instance, so looking it up lazily would recurse into the wrapper.
        original_forward = module.forward

        @functools.wraps(original_forward)
        def _wrapped(*args, **kwargs):
            new_args = pre_hook(module, args)
            if new_args is not None:
                # Same normalisation torch applies to pre-hook results.
                if not isinstance(new_args, tuple):
                    new_args = (new_args,)
                args = new_args

            output = original_forward(*args, **kwargs)

            new_output = hook(module, args, output)
            if new_output is not None:
                output = new_output
            return output

        module.forward = _wrapped
    else:
        raise ValueError(f"Unknown mode {mode!r}")
# -------------------------------------- util fn ------------------------------------------
@@ -604,14 +645,14 @@ def _collective_with_timeout(fn, operation_name: str, timeout_seconds: int = 60)
completed.set()
def _get_partial_name(timeout_seconds: int = 60):
def _get_default_exp_name(timeout_seconds: int = 60):
rank = _get_rank()
object_list = [str(time.time()) if rank == 0 else None]
object_list = [f"dump_{time.time()}" if rank == 0 else None]
if dist.is_initialized():
_collective_with_timeout(
lambda: dist.broadcast_object_list(object_list, device="cuda"),
operation_name="broadcast_object_list in _get_partial_name",
operation_name="broadcast_object_list in _get_default_exp_name",
timeout_seconds=timeout_seconds,
)
@@ -622,7 +663,7 @@ def _cleanup_old_dumps(base_dir: Path) -> None:
import shutil
if _get_rank() == 0:
for entry in base_dir.glob("sglang_dump_*"):
for entry in base_dir.glob("dump_*"):
if entry.is_dir():
shutil.rmtree(entry)
print(f"[Dumper] Cleaned up {entry}")