mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-07-03 13:57:04 +00:00
Enhance hook mechanism in dumper (#19073)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@@ -90,14 +91,14 @@ class _FrozenConfig(ABC):
|
||||
class _DumperConfig(_FrozenConfig):
|
||||
enable: bool = False
|
||||
filter: Optional[str] = None
|
||||
dir: str = "/tmp"
|
||||
dir: str = "/tmp/dumper"
|
||||
enable_output_file: bool = True
|
||||
enable_output_console: bool = True
|
||||
enable_value: bool = True
|
||||
enable_grad: bool = False
|
||||
enable_model_value: bool = True
|
||||
enable_model_grad: bool = True
|
||||
partial_name: Optional[str] = None
|
||||
exp_name: Optional[str] = None
|
||||
enable_http_server: bool = True
|
||||
cleanup_previous: bool = False
|
||||
collective_timeout: int = 60
|
||||
@@ -176,7 +177,7 @@ class _Dumper:
|
||||
return
|
||||
|
||||
# Users may want to `dump` only on some ranks, thus determine name here
|
||||
self._ensure_partial_name()
|
||||
self._ensure_exp_name()
|
||||
|
||||
self._step += 1
|
||||
print(f"[Dumper] [{time.time()}] step={self._step}")
|
||||
@@ -388,7 +389,7 @@ class _Dumper:
|
||||
save: bool,
|
||||
step: Optional[int] = None,
|
||||
) -> None:
|
||||
self._ensure_partial_name()
|
||||
self._ensure_exp_name()
|
||||
self._dump_index += 1
|
||||
|
||||
rank = _get_rank()
|
||||
@@ -399,11 +400,7 @@ class _Dumper:
|
||||
**tags,
|
||||
)
|
||||
full_filename = _format_tags(full_kwargs) + ".pt"
|
||||
path = (
|
||||
Path(self._config.dir)
|
||||
/ f"sglang_dump_{self._config.partial_name}"
|
||||
/ full_filename
|
||||
)
|
||||
path = Path(self._config.dir) / self._config.exp_name / full_filename
|
||||
|
||||
if self._config.enable_output_console:
|
||||
print(
|
||||
@@ -468,11 +465,13 @@ class _Dumper:
|
||||
_start_http_server(prefix="/dumper/", target=self, http_port=http_port)
|
||||
print(f"[Dumper] HTTP server started on port {http_port}")
|
||||
|
||||
def _ensure_partial_name(self):
|
||||
if self._config.partial_name is None:
|
||||
name = _get_partial_name(timeout_seconds=self._config.collective_timeout)
|
||||
self.configure(partial_name=name)
|
||||
print(f"[Dumper] Choose partial_name={name}")
|
||||
def _ensure_exp_name(self):
|
||||
if self._config.exp_name is None:
|
||||
name = _get_default_exp_name(
|
||||
timeout_seconds=self._config.collective_timeout
|
||||
)
|
||||
self.configure(exp_name=name)
|
||||
print(f"[Dumper] Choose exp_name={name}")
|
||||
|
||||
|
||||
# -------------------------------------- hook dumper ------------------------------------------
|
||||
@@ -496,11 +495,16 @@ class _NonIntrusiveDumper:
|
||||
if ctx := self._detect_module_ctx(module_name, module):
|
||||
self._register_ctx_hooks(module, ctx=ctx)
|
||||
|
||||
module.register_forward_hook(
|
||||
self._make_forward_hook(
|
||||
module_name=module_name,
|
||||
is_root=(module_name == ""),
|
||||
)
|
||||
is_root = module_name == ""
|
||||
pre_hook = self._make_forward_pre_hook(
|
||||
module_name=module_name, is_root=is_root
|
||||
)
|
||||
hook = self._make_forward_hook(module_name=module_name, is_root=is_root)
|
||||
_register_forward_hook_or_replace_fn(
|
||||
module,
|
||||
pre_hook=pre_hook,
|
||||
hook=hook,
|
||||
mode="replace_fn" if is_root else "hook",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -525,11 +529,15 @@ class _NonIntrusiveDumper:
|
||||
)
|
||||
)
|
||||
|
||||
def _make_forward_hook(self, *, module_name: str, is_root: bool):
|
||||
def _hook(_module, input, output):
|
||||
def _make_forward_pre_hook(self, *, module_name: str, is_root: bool):
|
||||
def _hook(_module, input):
|
||||
for i, item in enumerate(input):
|
||||
self._dump_value(module_name, item, role=f"inputs.{i}", is_root=is_root)
|
||||
|
||||
return _hook
|
||||
|
||||
def _make_forward_hook(self, *, module_name: str, is_root: bool):
|
||||
def _hook(_module, input, output):
|
||||
if output is not None:
|
||||
self._dump_value(module_name, output, role="output", is_root=False)
|
||||
|
||||
@@ -566,6 +574,39 @@ class _NonIntrusiveDumper:
|
||||
return {}
|
||||
|
||||
|
||||
def _register_forward_hook_or_replace_fn(
|
||||
module: "torch.nn.Module",
|
||||
*,
|
||||
pre_hook,
|
||||
hook,
|
||||
mode: str,
|
||||
) -> None:
|
||||
"""Attach pre/post forward hooks to *module*.
|
||||
|
||||
mode="hook" — standard ``register_forward_pre_hook`` / ``register_forward_hook``
|
||||
(fires only via ``__call__``).
|
||||
mode="replace_fn" — monkey-patch ``module.forward`` so hooks fire even when
|
||||
callers invoke ``.forward()`` directly (as sglang does for the
|
||||
root model).
|
||||
"""
|
||||
if mode == "hook":
|
||||
module.register_forward_pre_hook(pre_hook)
|
||||
module.register_forward_hook(hook)
|
||||
elif mode == "replace_fn":
|
||||
original_forward = module.forward
|
||||
|
||||
@functools.wraps(original_forward)
|
||||
def _wrapped(*args, **kwargs):
|
||||
pre_hook(module, args)
|
||||
output = original_forward(*args, **kwargs)
|
||||
hook(module, args, output)
|
||||
return output
|
||||
|
||||
module.forward = _wrapped
|
||||
else:
|
||||
raise ValueError(f"Unknown mode {mode!r}")
|
||||
|
||||
|
||||
# -------------------------------------- util fn ------------------------------------------
|
||||
|
||||
|
||||
@@ -604,14 +645,14 @@ def _collective_with_timeout(fn, operation_name: str, timeout_seconds: int = 60)
|
||||
completed.set()
|
||||
|
||||
|
||||
def _get_partial_name(timeout_seconds: int = 60):
|
||||
def _get_default_exp_name(timeout_seconds: int = 60):
|
||||
rank = _get_rank()
|
||||
object_list = [str(time.time()) if rank == 0 else None]
|
||||
object_list = [f"dump_{time.time()}" if rank == 0 else None]
|
||||
|
||||
if dist.is_initialized():
|
||||
_collective_with_timeout(
|
||||
lambda: dist.broadcast_object_list(object_list, device="cuda"),
|
||||
operation_name="broadcast_object_list in _get_partial_name",
|
||||
operation_name="broadcast_object_list in _get_default_exp_name",
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
|
||||
@@ -622,7 +663,7 @@ def _cleanup_old_dumps(base_dir: Path) -> None:
|
||||
import shutil
|
||||
|
||||
if _get_rank() == 0:
|
||||
for entry in base_dir.glob("sglang_dump_*"):
|
||||
for entry in base_dir.glob("dump_*"):
|
||||
if entry.is_dir():
|
||||
shutil.rmtree(entry)
|
||||
print(f"[Dumper] Cleaned up {entry}")
|
||||
|
||||
@@ -108,7 +108,7 @@ class TestEndToEnd(CustomTestCase):
|
||||
dumper.step()
|
||||
dumper.dump("tensor_b", tensor * 2)
|
||||
dumper.step()
|
||||
dump_dirs.append(Path(d) / f"sglang_dump_{dumper._config.partial_name}")
|
||||
dump_dirs.append(Path(d) / dumper._config.exp_name)
|
||||
|
||||
args = Namespace(
|
||||
baseline_path=str(dump_dirs[0]),
|
||||
|
||||
@@ -440,7 +440,7 @@ def _make_test_dumper(tmp_path, **overrides) -> _Dumper:
|
||||
config = _DumperConfig(
|
||||
enable=True,
|
||||
dir=str(tmp_path),
|
||||
partial_name="test",
|
||||
exp_name="test",
|
||||
enable_http_server=False,
|
||||
**overrides,
|
||||
)
|
||||
@@ -448,7 +448,7 @@ def _make_test_dumper(tmp_path, **overrides) -> _Dumper:
|
||||
|
||||
|
||||
def _get_filenames(tmpdir):
|
||||
return {f.name for f in Path(tmpdir).glob("sglang_dump_*/*.pt")}
|
||||
return {f.name for f in Path(tmpdir).glob("*/*.pt")}
|
||||
|
||||
|
||||
def _assert_files(filenames, *, exist=(), not_exist=()):
|
||||
@@ -468,7 +468,7 @@ def _load_dump(path: Path) -> dict:
|
||||
def _find_dump_file(tmpdir, *, rank: int = 0, name: str) -> Path:
|
||||
matches = [
|
||||
f
|
||||
for f in Path(tmpdir).glob("sglang_dump_*/*.pt")
|
||||
for f in Path(tmpdir).glob("*/*.pt")
|
||||
if f"rank={rank}" in f.name and name in f.name
|
||||
]
|
||||
assert (
|
||||
@@ -770,7 +770,7 @@ class TestDumpModel:
|
||||
|
||||
class TestCleanup:
|
||||
def test_cleanup_removes_old_dumps(self, tmp_path):
|
||||
old_dir = tmp_path / "sglang_dump_old"
|
||||
old_dir = tmp_path / "dump_old"
|
||||
old_dir.mkdir()
|
||||
(old_dir / "dummy.pt").touch()
|
||||
|
||||
@@ -781,7 +781,7 @@ class TestCleanup:
|
||||
_assert_files(_get_filenames(tmp_path), exist=["new_tensor"])
|
||||
|
||||
def test_no_cleanup_by_default(self, tmp_path):
|
||||
old_dir = tmp_path / "sglang_dump_old"
|
||||
old_dir = tmp_path / "dump_old"
|
||||
old_dir.mkdir()
|
||||
(old_dir / "dummy.pt").touch()
|
||||
|
||||
@@ -1020,6 +1020,43 @@ class TestNonIntrusiveDumper(_NonIntrusiveTestBase):
|
||||
P = self._PREFIX
|
||||
assert torch.allclose(captured[f"{P}output"]["value"], output)
|
||||
|
||||
def test_inputs_dumped_before_forward(self, tmp_path):
|
||||
"""Inputs are captured *before* forward(); in-place mutation must not affect them."""
|
||||
|
||||
class Mutator(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
x.fill_(999.0)
|
||||
return x
|
||||
|
||||
class Inner(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mutator = Mutator()
|
||||
|
||||
def forward(self, x):
|
||||
return self.mutator(x)
|
||||
|
||||
d = self._make_dumper(tmp_path)
|
||||
model = self._wrap_as_outer(Inner)
|
||||
d.register_non_intrusive_dumper(model)
|
||||
|
||||
x = torch.randn(2, 4)
|
||||
original_x = x.clone()
|
||||
with d.capture_output() as captured:
|
||||
model(x)
|
||||
|
||||
P = self._PREFIX
|
||||
dumped_input = captured[f"{P}model.mutator.inputs.0"]["value"]
|
||||
assert torch.allclose(dumped_input, original_x), (
|
||||
f"pre-hook should capture inputs before forward mutates them; "
|
||||
f"got {dumped_input} but expected {original_x}"
|
||||
)
|
||||
|
||||
dumped_output = captured[f"{P}model.mutator.output"]["value"]
|
||||
assert (
|
||||
dumped_output == 999.0
|
||||
).all(), "post-hook should capture outputs after forward"
|
||||
|
||||
def test_hooks_all_module_levels(self, tmp_path):
|
||||
class Attention(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
||||
Reference in New Issue
Block a user