Enhance hook mechanism in dumper (#19073)

This commit is contained in:
fzyzcjy
2026-02-22 16:13:38 +08:00
committed by GitHub
parent fdf80b5031
commit cc63c99f11
3 changed files with 109 additions and 31 deletions

View File

@@ -1,3 +1,4 @@
import functools
import json
import os
import re
@@ -90,14 +91,14 @@ class _FrozenConfig(ABC):
class _DumperConfig(_FrozenConfig):
enable: bool = False
filter: Optional[str] = None
dir: str = "/tmp"
dir: str = "/tmp/dumper"
enable_output_file: bool = True
enable_output_console: bool = True
enable_value: bool = True
enable_grad: bool = False
enable_model_value: bool = True
enable_model_grad: bool = True
partial_name: Optional[str] = None
exp_name: Optional[str] = None
enable_http_server: bool = True
cleanup_previous: bool = False
collective_timeout: int = 60
@@ -176,7 +177,7 @@ class _Dumper:
return
# Users may want to `dump` only on some ranks, thus determine name here
self._ensure_partial_name()
self._ensure_exp_name()
self._step += 1
print(f"[Dumper] [{time.time()}] step={self._step}")
@@ -388,7 +389,7 @@ class _Dumper:
save: bool,
step: Optional[int] = None,
) -> None:
self._ensure_partial_name()
self._ensure_exp_name()
self._dump_index += 1
rank = _get_rank()
@@ -399,11 +400,7 @@ class _Dumper:
**tags,
)
full_filename = _format_tags(full_kwargs) + ".pt"
path = (
Path(self._config.dir)
/ f"sglang_dump_{self._config.partial_name}"
/ full_filename
)
path = Path(self._config.dir) / self._config.exp_name / full_filename
if self._config.enable_output_console:
print(
@@ -468,11 +465,13 @@ class _Dumper:
_start_http_server(prefix="/dumper/", target=self, http_port=http_port)
print(f"[Dumper] HTTP server started on port {http_port}")
def _ensure_partial_name(self):
if self._config.partial_name is None:
name = _get_partial_name(timeout_seconds=self._config.collective_timeout)
self.configure(partial_name=name)
print(f"[Dumper] Choose partial_name={name}")
def _ensure_exp_name(self):
if self._config.exp_name is None:
name = _get_default_exp_name(
timeout_seconds=self._config.collective_timeout
)
self.configure(exp_name=name)
print(f"[Dumper] Choose exp_name={name}")
# -------------------------------------- hook dumper ------------------------------------------
@@ -496,11 +495,16 @@ class _NonIntrusiveDumper:
if ctx := self._detect_module_ctx(module_name, module):
self._register_ctx_hooks(module, ctx=ctx)
module.register_forward_hook(
self._make_forward_hook(
module_name=module_name,
is_root=(module_name == ""),
)
is_root = module_name == ""
pre_hook = self._make_forward_pre_hook(
module_name=module_name, is_root=is_root
)
hook = self._make_forward_hook(module_name=module_name, is_root=is_root)
_register_forward_hook_or_replace_fn(
module,
pre_hook=pre_hook,
hook=hook,
mode="replace_fn" if is_root else "hook",
)
@classmethod
@@ -525,11 +529,15 @@ class _NonIntrusiveDumper:
)
)
def _make_forward_hook(self, *, module_name: str, is_root: bool):
def _hook(_module, input, output):
def _make_forward_pre_hook(self, *, module_name: str, is_root: bool):
def _hook(_module, input):
for i, item in enumerate(input):
self._dump_value(module_name, item, role=f"inputs.{i}", is_root=is_root)
return _hook
def _make_forward_hook(self, *, module_name: str, is_root: bool):
def _hook(_module, input, output):
if output is not None:
self._dump_value(module_name, output, role="output", is_root=False)
@@ -566,6 +574,39 @@ class _NonIntrusiveDumper:
return {}
def _register_forward_hook_or_replace_fn(
module: "torch.nn.Module",
*,
pre_hook,
hook,
mode: str,
) -> None:
"""Attach pre/post forward hooks to *module*.
mode="hook" — standard ``register_forward_pre_hook`` / ``register_forward_hook``
(fires only via ``__call__``).
mode="replace_fn" — monkey-patch ``module.forward`` so hooks fire even when
callers invoke ``.forward()`` directly (as sglang does for the
root model).
"""
if mode == "hook":
module.register_forward_pre_hook(pre_hook)
module.register_forward_hook(hook)
elif mode == "replace_fn":
original_forward = module.forward
@functools.wraps(original_forward)
def _wrapped(*args, **kwargs):
pre_hook(module, args)
output = original_forward(*args, **kwargs)
hook(module, args, output)
return output
module.forward = _wrapped
else:
raise ValueError(f"Unknown mode {mode!r}")
# -------------------------------------- util fn ------------------------------------------
@@ -604,14 +645,14 @@ def _collective_with_timeout(fn, operation_name: str, timeout_seconds: int = 60)
completed.set()
def _get_partial_name(timeout_seconds: int = 60):
def _get_default_exp_name(timeout_seconds: int = 60):
rank = _get_rank()
object_list = [str(time.time()) if rank == 0 else None]
object_list = [f"dump_{time.time()}" if rank == 0 else None]
if dist.is_initialized():
_collective_with_timeout(
lambda: dist.broadcast_object_list(object_list, device="cuda"),
operation_name="broadcast_object_list in _get_partial_name",
operation_name="broadcast_object_list in _get_default_exp_name",
timeout_seconds=timeout_seconds,
)
@@ -622,7 +663,7 @@ def _cleanup_old_dumps(base_dir: Path) -> None:
import shutil
if _get_rank() == 0:
for entry in base_dir.glob("sglang_dump_*"):
for entry in base_dir.glob("dump_*"):
if entry.is_dir():
shutil.rmtree(entry)
print(f"[Dumper] Cleaned up {entry}")

View File

@@ -108,7 +108,7 @@ class TestEndToEnd(CustomTestCase):
dumper.step()
dumper.dump("tensor_b", tensor * 2)
dumper.step()
dump_dirs.append(Path(d) / f"sglang_dump_{dumper._config.partial_name}")
dump_dirs.append(Path(d) / dumper._config.exp_name)
args = Namespace(
baseline_path=str(dump_dirs[0]),

View File

@@ -440,7 +440,7 @@ def _make_test_dumper(tmp_path, **overrides) -> _Dumper:
config = _DumperConfig(
enable=True,
dir=str(tmp_path),
partial_name="test",
exp_name="test",
enable_http_server=False,
**overrides,
)
@@ -448,7 +448,7 @@ def _make_test_dumper(tmp_path, **overrides) -> _Dumper:
def _get_filenames(tmpdir):
return {f.name for f in Path(tmpdir).glob("sglang_dump_*/*.pt")}
return {f.name for f in Path(tmpdir).glob("*/*.pt")}
def _assert_files(filenames, *, exist=(), not_exist=()):
@@ -468,7 +468,7 @@ def _load_dump(path: Path) -> dict:
def _find_dump_file(tmpdir, *, rank: int = 0, name: str) -> Path:
matches = [
f
for f in Path(tmpdir).glob("sglang_dump_*/*.pt")
for f in Path(tmpdir).glob("*/*.pt")
if f"rank={rank}" in f.name and name in f.name
]
assert (
@@ -770,7 +770,7 @@ class TestDumpModel:
class TestCleanup:
def test_cleanup_removes_old_dumps(self, tmp_path):
old_dir = tmp_path / "sglang_dump_old"
old_dir = tmp_path / "dump_old"
old_dir.mkdir()
(old_dir / "dummy.pt").touch()
@@ -781,7 +781,7 @@ class TestCleanup:
_assert_files(_get_filenames(tmp_path), exist=["new_tensor"])
def test_no_cleanup_by_default(self, tmp_path):
old_dir = tmp_path / "sglang_dump_old"
old_dir = tmp_path / "dump_old"
old_dir.mkdir()
(old_dir / "dummy.pt").touch()
@@ -1020,6 +1020,43 @@ class TestNonIntrusiveDumper(_NonIntrusiveTestBase):
P = self._PREFIX
assert torch.allclose(captured[f"{P}output"]["value"], output)
def test_inputs_dumped_before_forward(self, tmp_path):
"""Inputs are captured *before* forward(); in-place mutation must not affect them."""
class Mutator(torch.nn.Module):
def forward(self, x):
x.fill_(999.0)
return x
class Inner(torch.nn.Module):
def __init__(self):
super().__init__()
self.mutator = Mutator()
def forward(self, x):
return self.mutator(x)
d = self._make_dumper(tmp_path)
model = self._wrap_as_outer(Inner)
d.register_non_intrusive_dumper(model)
x = torch.randn(2, 4)
original_x = x.clone()
with d.capture_output() as captured:
model(x)
P = self._PREFIX
dumped_input = captured[f"{P}model.mutator.inputs.0"]["value"]
assert torch.allclose(dumped_input, original_x), (
f"pre-hook should capture inputs before forward mutates them; "
f"got {dumped_input} but expected {original_x}"
)
dumped_output = captured[f"{P}model.mutator.output"]["value"]
assert (
dumped_output == 999.0
).all(), "post-hook should capture outputs after forward"
def test_hooks_all_module_levels(self, tmp_path):
class Attention(torch.nn.Module):
def __init__(self):