mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-07-03 05:47:24 +00:00
Enhance hook mechanism in dumper (#19073)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@@ -90,14 +91,14 @@ class _FrozenConfig(ABC):
|
||||
class _DumperConfig(_FrozenConfig):
|
||||
enable: bool = False
|
||||
filter: Optional[str] = None
|
||||
dir: str = "/tmp"
|
||||
dir: str = "/tmp/dumper"
|
||||
enable_output_file: bool = True
|
||||
enable_output_console: bool = True
|
||||
enable_value: bool = True
|
||||
enable_grad: bool = False
|
||||
enable_model_value: bool = True
|
||||
enable_model_grad: bool = True
|
||||
partial_name: Optional[str] = None
|
||||
exp_name: Optional[str] = None
|
||||
enable_http_server: bool = True
|
||||
cleanup_previous: bool = False
|
||||
collective_timeout: int = 60
|
||||
@@ -176,7 +177,7 @@ class _Dumper:
|
||||
return
|
||||
|
||||
# Users may want to `dump` only on some ranks, thus determine name here
|
||||
self._ensure_partial_name()
|
||||
self._ensure_exp_name()
|
||||
|
||||
self._step += 1
|
||||
print(f"[Dumper] [{time.time()}] step={self._step}")
|
||||
@@ -388,7 +389,7 @@ class _Dumper:
|
||||
save: bool,
|
||||
step: Optional[int] = None,
|
||||
) -> None:
|
||||
self._ensure_partial_name()
|
||||
self._ensure_exp_name()
|
||||
self._dump_index += 1
|
||||
|
||||
rank = _get_rank()
|
||||
@@ -399,11 +400,7 @@ class _Dumper:
|
||||
**tags,
|
||||
)
|
||||
full_filename = _format_tags(full_kwargs) + ".pt"
|
||||
path = (
|
||||
Path(self._config.dir)
|
||||
/ f"sglang_dump_{self._config.partial_name}"
|
||||
/ full_filename
|
||||
)
|
||||
path = Path(self._config.dir) / self._config.exp_name / full_filename
|
||||
|
||||
if self._config.enable_output_console:
|
||||
print(
|
||||
@@ -468,11 +465,13 @@ class _Dumper:
|
||||
_start_http_server(prefix="/dumper/", target=self, http_port=http_port)
|
||||
print(f"[Dumper] HTTP server started on port {http_port}")
|
||||
|
||||
def _ensure_partial_name(self):
    """Fill in ``partial_name`` on the config if it has not been set yet.

    When ``self._config.partial_name`` is ``None``, a name is obtained from
    ``_get_partial_name`` (using the configured ``collective_timeout``),
    stored through ``self.configure`` and announced on stdout. Otherwise this
    is a no-op.
    """
    if self._config.partial_name is not None:
        return

    timeout = self._config.collective_timeout
    chosen = _get_partial_name(timeout_seconds=timeout)
    self.configure(partial_name=chosen)
    print(f"[Dumper] Choose partial_name={chosen}")
|
||||
def _ensure_exp_name(self):
    """Fill in ``exp_name`` on the config if it has not been set yet.

    When ``self._config.exp_name`` is ``None``, a default name is obtained
    from ``_get_default_exp_name`` (using the configured
    ``collective_timeout``), stored through ``self.configure`` and announced
    on stdout. Otherwise this is a no-op.
    """
    if self._config.exp_name is not None:
        return

    timeout = self._config.collective_timeout
    chosen = _get_default_exp_name(timeout_seconds=timeout)
    self.configure(exp_name=chosen)
    print(f"[Dumper] Choose exp_name={chosen}")
|
||||
|
||||
|
||||
# -------------------------------------- hook dumper ------------------------------------------
|
||||
@@ -496,11 +495,16 @@ class _NonIntrusiveDumper:
|
||||
if ctx := self._detect_module_ctx(module_name, module):
|
||||
self._register_ctx_hooks(module, ctx=ctx)
|
||||
|
||||
module.register_forward_hook(
|
||||
self._make_forward_hook(
|
||||
module_name=module_name,
|
||||
is_root=(module_name == ""),
|
||||
)
|
||||
is_root = module_name == ""
|
||||
pre_hook = self._make_forward_pre_hook(
|
||||
module_name=module_name, is_root=is_root
|
||||
)
|
||||
hook = self._make_forward_hook(module_name=module_name, is_root=is_root)
|
||||
_register_forward_hook_or_replace_fn(
|
||||
module,
|
||||
pre_hook=pre_hook,
|
||||
hook=hook,
|
||||
mode="replace_fn" if is_root else "hook",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -525,11 +529,15 @@ class _NonIntrusiveDumper:
|
||||
)
|
||||
)
|
||||
|
||||
def _make_forward_hook(self, *, module_name: str, is_root: bool):
|
||||
def _hook(_module, input, output):
|
||||
def _make_forward_pre_hook(self, *, module_name: str, is_root: bool):
    """Build a forward pre-hook that dumps every positional input of a module.

    Args:
        module_name: Name handed to ``self._dump_value`` for each input.
        is_root: Forwarded unchanged to ``self._dump_value``.

    Returns:
        A callable taking ``(module, args)``. For each ``args[i]`` it calls
        ``self._dump_value`` with role ``"inputs.{i}"`` and returns ``None``.
    """

    def _hook(_module, args):
        # The module itself is unused; only its positional inputs are dumped.
        for index, value in enumerate(args):
            role = f"inputs.{index}"
            self._dump_value(module_name, value, role=role, is_root=is_root)

    return _hook
|
||||
|
||||
def _make_forward_hook(self, *, module_name: str, is_root: bool):
|
||||
def _hook(_module, input, output):
|
||||
if output is not None:
|
||||
self._dump_value(module_name, output, role="output", is_root=False)
|
||||
|
||||
@@ -566,6 +574,39 @@ class _NonIntrusiveDumper:
|
||||
return {}
|
||||
|
||||
|
||||
def _register_forward_hook_or_replace_fn(
    module: "torch.nn.Module",
    *,
    pre_hook,
    hook,
    mode: str,
) -> None:
    """Attach pre/post forward hooks to *module*.

    mode="hook"       — standard ``register_forward_pre_hook`` / ``register_forward_hook``
                        (fires only via ``__call__``).
    mode="replace_fn" — monkey-patch ``module.forward`` so hooks fire even when
                        callers invoke ``.forward()`` directly (as sglang does for the
                        root model).

    In ``replace_fn`` mode the wrapper follows the same return-value
    convention as registered hooks: a non-``None`` result from *pre_hook*
    replaces the positional arguments (a single value is wrapped into a
    1-tuple), and a non-``None`` result from *hook* replaces the output.
    Hooks that return ``None`` leave arguments and output untouched.

    Args:
        module: Module to instrument.
        pre_hook: Callable invoked as ``pre_hook(module, args)`` before forward.
        hook: Callable invoked as ``hook(module, args, output)`` after forward.
        mode: Either ``"hook"`` or ``"replace_fn"``.

    Raises:
        ValueError: If *mode* is neither ``"hook"`` nor ``"replace_fn"``.
    """
    if mode == "hook":
        module.register_forward_pre_hook(pre_hook)
        module.register_forward_hook(hook)
        return

    if mode != "replace_fn":
        raise ValueError(f"Unknown mode {mode!r}")

    # Capture the current forward so the wrapper can delegate to it.
    original_forward = module.forward

    @functools.wraps(original_forward)
    def _wrapped(*args, **kwargs):
        # Keyword arguments are passed to forward but, as with default
        # registered hooks, are not shown to the hooks themselves.
        replaced_args = pre_hook(module, args)
        if replaced_args is not None:
            if not isinstance(replaced_args, tuple):
                replaced_args = (replaced_args,)
            args = replaced_args

        output = original_forward(*args, **kwargs)

        replaced_output = hook(module, args, output)
        if replaced_output is not None:
            output = replaced_output
        return output

    module.forward = _wrapped
|
||||
|
||||
|
||||
# -------------------------------------- util fn ------------------------------------------
|
||||
|
||||
|
||||
@@ -604,14 +645,14 @@ def _collective_with_timeout(fn, operation_name: str, timeout_seconds: int = 60)
|
||||
completed.set()
|
||||
|
||||
|
||||
def _get_partial_name(timeout_seconds: int = 60):
|
||||
def _get_default_exp_name(timeout_seconds: int = 60):
|
||||
rank = _get_rank()
|
||||
object_list = [str(time.time()) if rank == 0 else None]
|
||||
object_list = [f"dump_{time.time()}" if rank == 0 else None]
|
||||
|
||||
if dist.is_initialized():
|
||||
_collective_with_timeout(
|
||||
lambda: dist.broadcast_object_list(object_list, device="cuda"),
|
||||
operation_name="broadcast_object_list in _get_partial_name",
|
||||
operation_name="broadcast_object_list in _get_default_exp_name",
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
|
||||
@@ -622,7 +663,7 @@ def _cleanup_old_dumps(base_dir: Path) -> None:
|
||||
import shutil
|
||||
|
||||
if _get_rank() == 0:
|
||||
for entry in base_dir.glob("sglang_dump_*"):
|
||||
for entry in base_dir.glob("dump_*"):
|
||||
if entry.is_dir():
|
||||
shutil.rmtree(entry)
|
||||
print(f"[Dumper] Cleaned up {entry}")
|
||||
|
||||
Reference in New Issue
Block a user