sglang/docs_new/docs/advanced_features/forward_hooks.mdx
Mingyi a3291b5654 Add new Mintlify documentation site (docs_new/) (#23001)
Co-authored-by: AdityaVKochar <adityavardhankochar@gmail.com>
Co-authored-by: mintlify[bot] <109931778+mintlify[bot]@users.noreply.github.com>
Co-authored-by: adhyan-jain <adhyanjain2006@gmail.com>
Co-authored-by: Adhyan Jain <71976554+adhyan-jain@users.noreply.github.com>
Co-authored-by: Maitri-shah29 <maitrirajivshah@gmail.com>
Co-authored-by: Adarsh Shirawalmath <114558126+adarshxs@users.noreply.github.com>
Co-authored-by: Maitri Shah <shah29maitri@gmail.com>
Co-authored-by: Aditya Vardhan Kochar <80113212+AdityaVKochar@users.noreply.github.com>
Co-authored-by: Rishit Shivam <164783543+pokymono@users.noreply.github.com>
Co-authored-by: Rishitshivam <164783543+Rishitshivam@users.noreply.github.com>
Co-authored-by: IshhanKheria <ishhankheria06@gmail.com>
Co-authored-by: Ishita Joshi <ishitata.joshi@gmail.com>
Co-authored-by: Richard Chen <104477092+Richardczl98@users.noreply.github.com>
Co-authored-by: longGGGGGG <553746008@qq.com>
Co-authored-by: Richard <richardchen@radixark.ai>
Co-authored-by: Nakul Sinha <nakul.new4socials@gmail.com>
Co-authored-by: Divyam Agrawal <ludicrouslytrue@gmail.com>
Co-authored-by: Richardczl98 <Zhenlinc@stanford.edu>
Co-authored-by: Krishang Zinzuwadia <krishangzinzuwadia@gmail.com>
Co-authored-by: nimeshas <nimesha.s106@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Jignas Paturu <86356085+JignasP@users.noreply.github.com>
Co-authored-by: zijiexia <37504505+zijiexia@users.noreply.github.com>
2026-04-20 15:10:22 -07:00

---
title: "Model Forward Hooks"
metatags:
description: "SGLang forward hooks: attach PyTorch hooks to model submodules via JSON config. Log activations, debug internals, export hidden states."
---
## Model Hooks
SGLang supports attaching PyTorch forward hooks to specific submodules in the loaded model, configured entirely via `server_args` JSON.
This is useful for:
* Logging intermediate activations
* Debugging model internals
* Exporting hidden states to external tooling
Hooks are attached once during `ModelRunner.initialize` and run on every forward pass.
***
### Configuration overview
Hooks are configured via a `ServerArgs` field:
```python Example
class ServerArgs:
    ...
    # For forward hooks
    forward_hooks: Optional[List[dict[str, Any]]] = None
```
In JSON form, a minimal configuration looks like:
```jsonc Example
{
  "forward_hooks": [
    {
      "name": "outer_linear_hooks",
      "target_modules": ["outer.0", "outer.1"],
      "hook_factory": "my_project.hooks:dummy_hook_factory",
      "config": {
        "tag": "outer-layer"
      }
    }
  ]
}
```
#### Top-level fields
* `forward_hooks` (optional list of objects). Each element is a hook spec describing:
  * Which modules to target
  * Which Python factory to call
  * What configuration to pass into that factory
***
### Hook spec schema
Each entry in `forward_hooks` is a JSON object with the following shape:
```jsonc Example
{
  "name": "optional-descriptive-name",
  "target_modules": ["pattern1", "pattern2", "..."],
  "hook_factory": "module.submodule:factory_name",
  "config": {
    "...": "arbitrary JSON"
  }
}
```
#### `name` (optional)
* Human-readable name for logging.
* Used only in log messages such as:
```text Output
Registered forward hook 'outer_linear_hooks' on outer.0
```
#### `target_modules` (required)
* List of **module name patterns** used to match entries in `model.named_modules()`.
* Patterns are matched using `fnmatch.fnmatch`, so:
  * `"outer.0"` matches exactly `"outer.0"`.
  * `"outer.*"` matches `"outer.0"`, `"outer.1"`, `"outer.inner"`, etc. Because `*` also matches dots, it matches deeper descendants such as `"outer.inner.0"` as well.
  * `"outer.inner.*"` matches every module nested under `outer.inner`.
> If no modules match the given patterns, hook registration does **not** fail.
> Instead, SGLang logs a warning and continues:
>
> ```text
> No modules matched hook spec 'name' patterns=['...']
> ```
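To see how these patterns behave, the sketch below applies `fnmatch.fnmatch` to a hypothetical list of module names. The names are illustrative stand-ins for the keys of `model.named_modules()`, not taken from a real model, and `match_modules` is a helper written only for this example.

```python
import fnmatch

# Hypothetical module names, standing in for the keys of model.named_modules().
MODULE_NAMES = ["", "outer", "outer.0", "outer.1", "outer.inner", "outer.inner.0", "head"]


def match_modules(names, patterns):
    """Return the names matching at least one pattern, in their original order."""
    return [n for n in names if any(fnmatch.fnmatch(n, p) for p in patterns)]


exact = match_modules(MODULE_NAMES, ["outer.0"])
wildcard = match_modules(MODULE_NAMES, ["outer.*"])        # '*' also crosses dots
nested = match_modules(MODULE_NAMES, ["outer.inner.*"])
unmatched = match_modules(MODULE_NAMES, ["decoder.*"])     # would only log a warning

print(exact)
print(wildcard)
print(nested)
print(unmatched)
```

Note that `"outer.*"` does not match `"outer"` itself, since the pattern requires the literal dot.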
#### `hook_factory` (required)
* String path to the Python factory function that creates the hook.
* Supported formats:
  * `"package.module:factory_name"`
  * `"package.module.submodule.factory_name"`
The path is resolved via:
```python Example
import importlib
from typing import Callable, Optional


def resolve_callable(path: Optional[str]) -> Optional[Callable]:
    if path is None:
        return None
    if ":" in path:
        # "package.module:factory_name"
        module_name, fn_name = path.split(":", 1)
    else:
        # "package.module.factory_name": the last component is the attribute
        parts = path.split(".")
        if len(parts) < 2:
            raise ValueError(
                f"Invalid hook callable path '{path}'. "
                "Expected 'module.submodule:factory' or 'module.submodule.factory'."
            )
        *mod_parts, fn_name = parts
        module_name = ".".join(mod_parts)
    module = importlib.import_module(module_name)
    try:
        return getattr(module, fn_name)
    except AttributeError as e:
        raise AttributeError(
            f"Module '{module_name}' has no attribute '{fn_name}' "
            f"(from hook path '{path}')"
        ) from e
```
**Failure modes**:
* If the path is malformed (no `:` and no `.`), a `ValueError` is raised at startup.
* If the module cannot be imported, `importlib.import_module` raises an `ImportError` (typically `ModuleNotFoundError`).
* If the module imports but the attribute is missing, an `AttributeError` is raised with a message that names the module, the attribute, and the original path.
* If the hook factory returns `None`, a warning is logged and no hook is registered for that spec (initialization continues).

The first three cause initialization to fail fast with a descriptive error; the last one is non-fatal.
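The two path formats and the exception types can be exercised against standard-library callables. The sketch below uses a condensed resolver that follows the same splitting rules as `resolve_callable`. Both `resolve` and `failure_type` are hypothetical stand-ins written for this example, not SGLang functions.

```python
import importlib
import os.path


def resolve(path):
    """Condensed stand-in for the resolver: split on ':' or on the last '.'."""
    if ":" in path:
        module_name, fn_name = path.split(":", 1)
    else:
        module_name, _, fn_name = path.rpartition(".")
        if not module_name:
            raise ValueError(f"Invalid hook callable path '{path}'")
    return getattr(importlib.import_module(module_name), fn_name)


def failure_type(path):
    """Return the name of the exception raised while resolving path, or None."""
    try:
        resolve(path)
    except Exception as exc:  # broad on purpose: only the type is reported
        return type(exc).__name__
    return None


# Both formats resolve to the same object.
colon_form = resolve("os.path:join")
dotted_form = resolve("os.path.join")

print(colon_form is dotted_form)
print(failure_type("nodots"))                   # malformed path
print(failure_type("os.path:no_such_attr"))     # module imports, attribute missing
print(failure_type("no_such_pkg_xyz.hooks:f"))  # module cannot be imported
```

Checking a hook path this way in a plain Python session, in the same environment the server runs in, is a quick way to catch typos before startup.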
#### `config` (optional)
* Arbitrary JSON object.
* Passed directly to the hook factory as a Python `dict`; if `config` is omitted or `null`, the factory receives an empty `dict`.
* This lets you parameterize hook behavior from config (e.g. tags, log levels, or sampling rates).
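As an illustration of parameterizing a hook from `config`, the following sketch reads two hypothetical keys, `tag` and `every_n`, with defaults. It validates them when the factory is called, so that a bad value surfaces as an exception at registration time rather than during a forward pass. The key names and `sampling_hook_factory` are assumptions made for this example.

```python
import logging

logger = logging.getLogger(__name__)


def sampling_hook_factory(config):
    """Hypothetical factory: log every Nth call, tagged from config."""
    tag = config.get("tag", "default")
    every_n = config.get("every_n", 1)
    # bool is a subclass of int, so reject it explicitly.
    if not isinstance(every_n, int) or isinstance(every_n, bool) or every_n < 1:
        raise ValueError(f"'every_n' must be a positive integer, got {every_n!r}")

    state = {"calls": 0, "logged": 0}

    def hook(module, inputs, output):
        state["calls"] += 1
        if state["calls"] % every_n == 0:
            state["logged"] += 1
            logger.info("[%s] %s call #%d", tag, type(module).__name__, state["calls"])
        # No return value: the module output is left unchanged.

    hook.state = state  # exposed only so this sketch can be inspected
    return hook


# Exercise the hook with placeholder arguments; no model is needed.
hook = sampling_hook_factory({"tag": "mlp", "every_n": 3})
for _ in range(7):
    hook(object(), (), None)
print(hook.state)
```

Keeping per-hook state in the closure, rather than in a module-level global, means two specs that use the same factory do not share counters.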
***
### Hook lifecycle and behavior
Hooks are registered in `ModelRunner.initialize()`:
```python Example
if server_args.forward_hooks:
    register_forward_hooks(self.model, server_args.forward_hooks)
```
The actual registration logic is implemented by `register_forward_hooks`:
```python Example
def register_forward_hooks(model: nn.Module, hook_specs: List[dict[str, Any]]) -> None:
    """
    hook_specs is a list of dicts from server_args.forward_hooks.
    Attaches forward hooks to the matching modules.
    """
    name_to_module = dict(model.named_modules())
    for spec in hook_specs:
        spec_name = spec.get("name", "")
        target_patterns = spec.get("target_modules", [])
        if not target_patterns:
            logger.warning(
                f"Hook spec '{spec_name}' has no 'target_modules', skipping"
            )
            continue
        hook_factory_path = spec.get("hook_factory")
        if not hook_factory_path:
            logger.warning(
                f"Hook spec '{spec_name}' has no 'hook_factory', skipping"
            )
            continue
        config = spec.get("config") or {}
        hook_factory = resolve_callable(hook_factory_path)
        hook = hook_factory(config) if hook_factory else None
        if hook is None:
            logger.warning(
                f"Hook factory '{hook_factory_path}' for spec '{spec_name}' "
                "returned None, not registering any hook"
            )
            continue
        # Resolve patterns like "model.layers.*.mlp"
        matched = []
        for name, module in name_to_module.items():
            if any(fnmatch.fnmatch(name, pattern) for pattern in target_patterns):
                matched.append((name, module))
        if not matched:
            logger.warning(
                f"No modules matched hook spec '{spec_name}' "
                f"patterns={target_patterns}"
            )
            continue
        for module_name, module in matched:
            if hook:
                _ = module.register_forward_hook(hook)
                logger.info(
                    f"Registered forward hook '{spec_name}' "
                    f"on {module_name}"
                )
```
Key points:
* Hooks are **forward hooks only** (via `module.register_forward_hook`).
* They are attached once at initialization. The factory is called once per spec, and the hook it returns is shared by every matched module.
* Hook handles are currently not stored on `ModelRunner` (they cannot be removed later via this API).
* A spec that omits `target_modules` or `hook_factory`, or that matches no modules, is non-fatal; a warning is logged and the spec is skipped.
* If a hook factory returns `None`, a warning is logged and that spec is skipped.
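Because several misconfigurations are only logged as warnings, it can help to check a spec list before launching the server. The sketch below is a hypothetical pre-flight helper, `check_hook_specs`, which is not part of SGLang. It mirrors the skip conditions in `register_forward_hooks` against a list of module names supplied by the caller, and it does not import or call the factory.

```python
import fnmatch


def check_hook_specs(hook_specs, module_names):
    """Return human-readable problems that would make a spec register no hook."""
    problems = []
    for index, spec in enumerate(hook_specs):
        label = spec.get("name") or f"#{index}"
        patterns = spec.get("target_modules") or []
        if not patterns:
            problems.append(f"{label}: missing 'target_modules'")
            continue
        if not spec.get("hook_factory"):
            problems.append(f"{label}: missing 'hook_factory'")
            continue
        if not any(fnmatch.fnmatch(n, p) for n in module_names for p in patterns):
            problems.append(f"{label}: no module matches {patterns}")
    return problems


# Hypothetical module names; in practice, take them from model.named_modules().
names = ["outer", "outer.0", "outer.1"]
specs = [
    {"name": "ok", "target_modules": ["outer.*"], "hook_factory": "my_project.hooks:f"},
    {"name": "typo", "target_modules": ["outter.*"], "hook_factory": "my_project.hooks:f"},
    {"target_modules": ["outer.0"]},  # unnamed spec without a factory
]
problems = check_hook_specs(specs, names)
for problem in problems:
    print(problem)
```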
***
### Writing a hook factory
A hook factory is a regular Python function:
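* Takes a `config: dict` (from JSON)
* Returns a forward hook function with signature `(module, inputs, output)`. As with any PyTorch forward hook, returning `None` leaves the module output unchanged, while any other return value replaces it.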
Example:
```python Example
HOOK_CALLS = []


def dummy_hook_factory(config):
    """Factory that returns a forward hook capturing a tag from config."""
    tag = config.get("tag", "default")

    def hook(module, inputs, output):
        HOOK_CALLS.append(
            {
                "module_type": type(module).__name__,
                "tag": tag,
                "shape": tuple(output.shape),
            }
        )
        # Returning the output unchanged is equivalent to returning None;
        # any other non-None return value would replace the module's output.
        return output

    return hook
```
In JSON:
```jsonc Example
{
  "forward_hooks": [
    {
      "name": "capture_outer",
      "target_modules": ["outer.0", "outer.1"],
      "hook_factory": "my_project.hooks:dummy_hook_factory",
      "config": {
        "tag": "outer"
      }
    }
  ]
}
```
This will:
* Resolve `my_project.hooks:dummy_hook_factory` to a Python callable.
* Call it with `config = {"tag": "outer"}`.
* Use the returned hook for all modules matching `outer.0` and `outer.1`.
* Append metadata about each call to `HOOK_CALLS`.
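The example hook reads `output.shape`, which assumes the target module returns a single tensor. Some modules return tuples or other containers, so a more defensive variant may be useful. The sketch below shows one hypothetical way to do that with duck typing. `shape_logging_hook_factory`, `describe`, and `FakeTensor` are names invented for this example, and `FakeTensor` exists only so the sketch runs without PyTorch.

```python
RECORDS = []


def describe(value):
    """Summarize an output: shape for tensor-like objects, element-wise for tuples/lists."""
    if hasattr(value, "shape"):
        return tuple(value.shape)
    if isinstance(value, (tuple, list)):
        return [describe(item) for item in value]
    return type(value).__name__


def shape_logging_hook_factory(config):
    """Hypothetical factory whose hook tolerates non-tensor outputs."""
    tag = config.get("tag", "default")

    def hook(module, inputs, output):
        RECORDS.append(
            {
                "tag": tag,
                "module_type": type(module).__name__,
                "output": describe(output),
            }
        )
        # No return value: the module output is passed through unchanged.

    return hook


class FakeTensor:
    """Stand-in for a tensor; only the 'shape' attribute is modeled."""

    def __init__(self, *shape):
        self.shape = shape


hook = shape_logging_hook_factory({"tag": "attn"})
hook(object(), (), FakeTensor(2, 8))           # single tensor-like output
hook(object(), (), (FakeTensor(2, 8), None))   # tuple output with a non-tensor item
print(RECORDS)
```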
***
### Summary
* Define `forward_hooks` as a list of specs in `ServerArgs` to turn on the feature.
* Each spec:
  * selects modules via `target_modules` (glob patterns over `model.named_modules()`),
  * points to a hook factory via `hook_factory`,
  * passes arbitrary `config` into that factory.
* Hook factories are resolved via `resolve_callable`, which supports `module:factory` and `module.submodule.factory`.
* Hooks are standard PyTorch forward hooks, attached once at startup and invoked on every forward pass.
* Misconfiguration is either:
  * **fatal and explicit** (bad path / missing attribute), or
  * **non-fatal with clear warnings** (no targets matched, or factory returned `None`).