mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-06-30 19:57:52 +00:00
Co-authored-by: AdityaVKochar <adityavardhankochar@gmail.com> Co-authored-by: mintlify[bot] <109931778+mintlify[bot]@users.noreply.github.com> Co-authored-by: adhyan-jain <adhyanjain2006@gmail.com> Co-authored-by: Adhyan Jain <71976554+adhyan-jain@users.noreply.github.com> Co-authored-by: Maitri-shah29 <maitrirajivshah@gmail.com> Co-authored-by: Adarsh Shirawalmath <114558126+adarshxs@users.noreply.github.com> Co-authored-by: Maitri Shah <shah29maitri@gmail.com> Co-authored-by: Aditya Vardhan Kochar <80113212+AdityaVKochar@users.noreply.github.com> Co-authored-by: Rishit Shivam <164783543+pokymono@users.noreply.github.com> Co-authored-by: Rishitshivam <164783543+Rishitshivam@users.noreply.github.com> Co-authored-by: IshhanKheria <ishhankheria06@gmail.com> Co-authored-by: Ishita Joshi <ishitata.joshi@gmail.com> Co-authored-by: Richard Chen <104477092+Richardczl98@users.noreply.github.com> Co-authored-by: longGGGGGG <553746008@qq.com> Co-authored-by: Richard <richardchen@radixark.ai> Co-authored-by: Nakul Sinha <nakul.new4socials@gmail.com> Co-authored-by: Divyam Agrawal <ludicrouslytrue@gmail.com> Co-authored-by: Richardczl98 <Zhenlinc@stanford.edu> Co-authored-by: Krishang Zinzuwadia <krishangzinzuwadia@gmail.com> Co-authored-by: nimeshas <nimesha.s106@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Jignas Paturu <86356085+JignasP@users.noreply.github.com> Co-authored-by: zijiexia <37504505+zijiexia@users.noreply.github.com>
---
title: "Model Forward Hooks"
metatags:
  description: "SGLang forward hooks: attach PyTorch hooks to model submodules via JSON config. Log activations, debug internals, export hidden states."
---

## Model Hooks

SGLang supports attaching PyTorch forward hooks to specific submodules in the loaded model, configured entirely via `server_args` JSON.

This is useful for:

* Logging intermediate activations
* Debugging model internals
* Exporting hidden states to external tooling

Hooks are attached once during `ModelRunner.initialize` and run on every forward pass.

***

### Configuration overview

Hooks are configured via a `ServerArgs` field:

```python Example
class ServerArgs:
    ...
    # For forward hooks
    forward_hooks: Optional[List[dict[str, Any]]] = None
```

In JSON form, a minimal configuration looks like:

```jsonc Example
{
  "forward_hooks": [
    {
      "name": "outer_linear_hooks",
      "target_modules": ["outer.0", "outer.1"],
      "hook_factory": "my_project.hooks:dummy_hook_factory",
      "config": {
        "tag": "outer-layer"
      }
    }
  ]
}
```

#### Top-level fields

* `forward_hooks` (optional list of objects)
  Each element is a hook spec describing:

  * Which modules to target
  * Which Python factory to call
  * What configuration to pass into that factory

***

### Hook spec schema

Each entry in `forward_hooks` is a JSON object with the following shape:

```jsonc Example
{
  "name": "optional-descriptive-name",
  "target_modules": ["pattern1", "pattern2", "..."],
  "hook_factory": "module.submodule:factory_name",
  "config": {
    "...": "arbitrary JSON"
  }
}
```

#### `name` (optional)

* Human-readable name for logging.
* Used only in log messages such as:

  ```text Output
  Registered forward hook 'outer_linear_hooks' on outer.0
  ```

#### `target_modules` (required)

* List of **module name patterns** used to match entries in `model.named_modules()`.
* Patterns are matched using `fnmatch.fnmatch`, so:

  * `"outer.0"` matches exactly `"outer.0"`.
  * `"outer.*"` matches `"outer.0"`, `"outer.1"`, `"outer.inner"`, etc.
  * `"outer.inner.*"` matches every descendant of `outer.inner`, because `*` also matches dots.

* A spec whose `target_modules` is missing or empty is not an error: it is skipped with a warning.

> If no modules match the given patterns, hook registration does **not** fail.
> Instead, SGLang logs a warning and continues:
>
> ```text
> No modules matched hook spec 'name' patterns=['...']
> ```
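
The matching rules above can be checked without loading a model. The following is a minimal sketch that applies `fnmatch.fnmatch` to a hand-written list of names; the names and the `match_modules` helper are hypothetical stand-ins for the keys of `dict(model.named_modules())`, not part of SGLang.

```python
import fnmatch

# Hypothetical module names, shaped like the keys of dict(model.named_modules()).
module_names = [
    "outer",
    "outer.0",
    "outer.1",
    "outer.inner",
    "outer.inner.proj",
    "head",
]


def match_modules(names, patterns):
    """Return the names matched by at least one pattern, in input order."""
    return [
        name
        for name in names
        if any(fnmatch.fnmatch(name, pattern) for pattern in patterns)
    ]


exact = match_modules(module_names, ["outer.0"])
wildcard = match_modules(module_names, ["outer.*"])
nested = match_modules(module_names, ["outer.inner.*"])

print(exact)     # ['outer.0']
print(wildcard)  # ['outer.0', 'outer.1', 'outer.inner', 'outer.inner.proj']
print(nested)    # ['outer.inner.proj']
```

Note that `"outer.*"` does not match `"outer"` itself (the pattern requires the dot), and that it reaches nested names such as `"outer.inner.proj"` because `*` is not limited to a single path segment.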

#### `hook_factory` (required)

* String path to the Python factory function that creates the hook.
* Supported formats:

  * `"package.module:factory_name"`
  * `"package.module.submodule.factory_name"`

The path is resolved via:

```python Example
import importlib
from typing import Callable, Optional


def resolve_callable(path: Optional[str]) -> Optional[Callable]:
    if path is None:
        return None

    if ":" in path:
        module_name, fn_name = path.split(":", 1)
    else:
        parts = path.split(".")
        if len(parts) < 2:
            raise ValueError(
                f"Invalid hook callable path '{path}'. "
                "Expected 'module.submodule:factory' or 'module.submodule.factory'."
            )
        *mod_parts, fn_name = parts
        module_name = ".".join(mod_parts)

    module = importlib.import_module(module_name)
    try:
        return getattr(module, fn_name)
    except AttributeError as e:
        raise AttributeError(
            f"Module '{module_name}' has no attribute '{fn_name}' "
            f"(from hook path '{path}')"
        ) from e
```

**Failure modes**:

* If the path is malformed (no `:` and no `.`), a `ValueError` is raised at startup.
* If the module imports but the attribute is missing, an `AttributeError` is raised whose message names the module, the attribute, and the original path.
* If the hook factory returns `None`, a warning is logged and no hook is registered for that spec (initialization continues).

The first two cause initialization to fail fast with a descriptive error; the last one is non-fatal.
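
To make the two path formats and the failure modes concrete, here is a small sketch that follows the same splitting rules and resolves standard-library callables. `split_hook_path`, `load`, and `failure_kind` are hypothetical helper names, and `no_such_package_xyz` is a made-up module name assumed not to be installed. The sketch also surfaces one case not listed above: when the module itself cannot be imported, `importlib.import_module` raises `ModuleNotFoundError`.

```python
import importlib
import json


def split_hook_path(path):
    """Split a hook path into (module_name, attribute_name)."""
    if ":" in path:
        module_name, attr_name = path.split(":", 1)
    else:
        module_name, _, attr_name = path.rpartition(".")
        if not module_name:
            raise ValueError(f"Invalid hook callable path '{path}'")
    return module_name, attr_name


def load(path):
    """Import the module part of the path and return the named attribute."""
    module_name, attr_name = split_hook_path(path)
    return getattr(importlib.import_module(module_name), attr_name)


def failure_kind(path):
    """Return the exception class name raised while loading, or None on success."""
    try:
        load(path)
    except Exception as exc:
        return type(exc).__name__
    return None


# Both formats resolve to the same object.
colon_form = load("json:dumps")
dotted_form = load("json.dumps")
print(colon_form is dotted_form is json.dumps)

print(failure_kind("dumps"))                        # malformed path
print(failure_kind("json:no_such_function"))        # attribute missing
print(failure_kind("no_such_package_xyz:factory"))  # module cannot be imported
```

Running a check like this against your own factory path before starting the server can catch typos early, since path errors are fatal at initialization.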

#### `config` (optional)

* Arbitrary JSON object.
* Passed directly to the hook factory as a Python `dict`. If `config` is omitted or `null`, the factory receives an empty `dict`.
* This lets you parameterize hook behavior from config (e.g. tags, log levels, sampling rates).
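
As an illustration of config-driven behavior, the sketch below shows a factory that reads a tag and a sampling interval from `config`, with defaults for missing keys. The factory name, the `every_n` key, and the `SAMPLES` list are assumptions made for this example, not SGLang conventions; the demo at the bottom uses stand-in objects instead of a real module and tensor so that it runs without PyTorch.

```python
from types import SimpleNamespace

SAMPLES = []  # filled by the hook; a real hook might log or write to disk instead


def sampled_shape_hook_factory(config):
    """Hypothetical factory: record the output shape on every Nth call."""
    tag = config.get("tag", "default")
    every_n = max(1, int(config.get("every_n", 1)))
    calls = 0

    def hook(module, inputs, output):
        nonlocal calls
        calls += 1
        if calls % every_n == 0:
            shape = tuple(getattr(output, "shape", ()))
            SAMPLES.append({"tag": tag, "call": calls, "shape": shape})
        # No return value: the module output is left unchanged.

    return hook


# Stand-in objects replace a real module and tensor for this demo.
hook = sampled_shape_hook_factory({"tag": "mlp", "every_n": 2})
fake_output = SimpleNamespace(shape=(4, 8))
for _ in range(5):
    hook(object(), (), fake_output)

print(SAMPLES)  # two entries: calls 2 and 4
```

In a hook spec, the corresponding entry would be `"config": {"tag": "mlp", "every_n": 2}`. Because the factory is called once per spec, the call counter in this sketch is shared by every module that the spec matches.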

***

### Hook lifecycle and behavior

Hooks are registered in `ModelRunner.initialize()`:

```python Example
if server_args.forward_hooks:
    register_forward_hooks(self.model, server_args.forward_hooks)
```

The actual registration logic is implemented by `register_forward_hooks`:

```python Example
import fnmatch
import logging
from typing import Any, List

from torch import nn

logger = logging.getLogger(__name__)


def register_forward_hooks(model: nn.Module, hook_specs: List[dict[str, Any]]) -> None:
    """
    hook_specs is a list of dicts from server_args.forward_hooks.
    Attaches forward hooks to the matching modules.
    """
    name_to_module = dict(model.named_modules())

    for spec in hook_specs:
        spec_name = spec.get("name", "")
        target_patterns = spec.get("target_modules", [])
        if not target_patterns:
            logger.warning(
                f"Hook spec '{spec_name}' has no 'target_modules', skipping"
            )
            continue

        hook_factory_path = spec.get("hook_factory")
        if not hook_factory_path:
            logger.warning(
                f"Hook spec '{spec_name}' has no 'hook_factory', skipping"
            )
            continue

        config = spec.get("config") or {}
        # resolve_callable is the function shown in the `hook_factory` section.
        hook_factory = resolve_callable(hook_factory_path)

        hook = hook_factory(config) if hook_factory else None
        if hook is None:
            logger.warning(
                f"Hook factory '{hook_factory_path}' for spec '{spec_name}' "
                "returned None, not registering any hook"
            )
            continue

        # Resolve patterns like "model.layers.*.mlp"
        matched = []
        for name, module in name_to_module.items():
            if any(fnmatch.fnmatch(name, pattern) for pattern in target_patterns):
                matched.append((name, module))

        if not matched:
            logger.warning(
                f"No modules matched hook spec '{spec_name}' "
                f"patterns={target_patterns}"
            )
            continue

        for module_name, module in matched:
            if hook:
                _ = module.register_forward_hook(hook)
                logger.info(
                    f"Registered forward hook '{spec_name}' "
                    f"on {module_name}"
                )
```

Key points:

* Hooks are **forward hooks only** (via `module.register_forward_hook`).
* They are attached once at initialization.
* The factory is called once per spec, and the single hook it returns is shared by every module that spec matches.
* Hook handles are currently not stored on `ModelRunner`, so hooks cannot be removed later via this API.
* Failure to match any modules is non-fatal; a warning is logged instead.
* If a hook factory returns `None`, a warning is logged and that spec is skipped.
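
Since an unmatched pattern only produces a warning, it can help to preview what a set of specs would match before starting the server. The following is a rough sketch of such a pre-flight check; `preview_hook_targets` is a hypothetical helper, the module names are invented, and the logic only mirrors the skip and match rules described above without touching a model or calling any factory.

```python
import fnmatch


def preview_hook_targets(module_names, hook_specs):
    """Map each spec label to the module names its patterns would match.

    Specs without 'target_modules' or 'hook_factory' map to None,
    because registration would skip them with a warning.
    """
    preview = {}
    for index, spec in enumerate(hook_specs):
        label = spec.get("name") or f"spec[{index}]"
        patterns = spec.get("target_modules", [])
        if not patterns or not spec.get("hook_factory"):
            preview[label] = None
            continue
        preview[label] = [
            name
            for name in module_names
            if any(fnmatch.fnmatch(name, pattern) for pattern in patterns)
        ]
    return preview


# Invented names; with a real model, use [n for n, _ in model.named_modules()].
names = [
    "model",
    "model.layers.0",
    "model.layers.0.mlp",
    "model.layers.0.mlp.gate_proj",
    "model.layers.1",
    "model.layers.1.mlp",
]
specs = [
    {"name": "mlp_hooks", "target_modules": ["model.layers.*.mlp"],
     "hook_factory": "my_project.hooks:dummy_hook_factory"},
    {"name": "typo", "target_modules": ["model.layer.*"],
     "hook_factory": "my_project.hooks:dummy_hook_factory"},
    {"name": "incomplete", "target_modules": ["model.*"]},
]

preview = preview_hook_targets(names, specs)
for label, targets in preview.items():
    print(label, "->", targets)
```

An empty list flags a spec whose patterns match nothing (here, the misspelled `model.layer.*`), while `None` flags a spec that would be skipped outright.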

***

### Writing a hook factory

A hook factory is a regular Python function that:

* Takes a `config: dict` (from JSON)
* Returns a forward hook function with signature `(module, inputs, output)`

Example:

```python Example
HOOK_CALLS = []


def dummy_hook_factory(config):
    """Factory that returns a forward hook capturing a tag from config."""
    tag = config.get("tag", "default")

    def hook(module, inputs, output):
        HOOK_CALLS.append(
            {
                "module_type": type(module).__name__,
                "tag": tag,
                "shape": tuple(output.shape),
            }
        )
        # Returning the same output leaves it unchanged (returning None has the
        # same effect); any other non-None return value replaces the output.
        return output

    return hook
```

In JSON:

```jsonc Example
{
  "forward_hooks": [
    {
      "name": "capture_outer",
      "target_modules": ["outer.0", "outer.1"],
      "hook_factory": "my_project.hooks:dummy_hook_factory",
      "config": {
        "tag": "outer"
      }
    }
  ]
}
```

This will:

* Resolve `my_project.hooks:dummy_hook_factory` to a Python callable.
* Call it with `config = {"tag": "outer"}`.
* Use the returned hook for all modules matching `outer.0` and `outer.1`.
* Append metadata about each call to `HOOK_CALLS`.
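
The example hook reads `output.shape`, which assumes the targeted module returns a single tensor. Some PyTorch modules return a tuple or list instead, in which case that attribute access would fail. The following is a hedged sketch of a more defensive variant; `first_tensor`, `shape_hook_factory`, and `SHAPES` are hypothetical names, and the demo uses stand-in objects rather than real tensors so that it runs without PyTorch.

```python
from types import SimpleNamespace

SHAPES = []


def first_tensor(output):
    """Return the first tensor-like item in a module output, or None."""
    if hasattr(output, "shape"):
        return output
    if isinstance(output, (tuple, list)):
        for item in output:
            found = first_tensor(item)
            if found is not None:
                return found
    return None


def shape_hook_factory(config):
    """Hypothetical factory: record the shape of the first tensor-like output."""
    tag = config.get("tag", "default")

    def hook(module, inputs, output):
        tensor = first_tensor(output)
        if tensor is not None:
            SHAPES.append((tag, type(module).__name__, tuple(tensor.shape)))
        # No return statement: the original output is kept as-is.

    return hook


# Stand-ins for a module, a tuple output, and a plain tensor output.
hook = shape_hook_factory({"tag": "outer"})
hook(object(), (), (SimpleNamespace(shape=(2, 3)), None))
hook(object(), (), SimpleNamespace(shape=(5,)))
hook(object(), (), None)  # nothing tensor-like: silently ignored

print(SHAPES)
```

Duck typing on `shape` keeps the sketch free of a PyTorch import; a real hook could instead check `isinstance(output, torch.Tensor)`.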

***

### Summary

* Define `forward_hooks` as a list of specs in `ServerArgs` to turn on the feature.

* Each spec:

  * selects modules via `target_modules` (glob patterns over `model.named_modules()`),
  * points to a hook factory via `hook_factory`,
  * passes arbitrary `config` into that factory.

* Hook factories are resolved via `resolve_callable`, which supports `module:factory` and `module.submodule.factory`.

* Hooks are standard PyTorch forward hooks, attached once at startup and invoked on every forward pass.

* Misconfiguration is either:

  * **fatal and explicit** (bad path / missing attribute), or
  * **non-fatal with clear warnings** (no targets matched, or factory returned `None`).