mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-06-30 19:57:52 +00:00
314 lines
9.7 KiB
Python
314 lines
9.7 KiB
Python
"""
|
|
Full server warmup to pre-warm Triton autotuning and CUDA graph capture.
|
|
|
|
On cold H200 nodes (new nodes or after container recreation), CUDA graph capture
|
|
triggers Triton autotuning which takes ~330s per server launch. This script
|
|
launches actual servers with CUDA graphs enabled to cache the autotuned kernels,
|
|
so subsequent test launches are fast (~30-60s).
|
|
|
|
Uses marker files to skip warmup on already-warm nodes. Marker files are
|
|
invalidated when Python, Triton, or PyTorch versions change.
|
|
|
|
Usage:
|
|
python3 scripts/ci/cuda/warmup_server.py \
|
|
deepseek-ai/DeepSeek-V3-0324:8 \
|
|
inclusionAI/Ring-2.5-1T:8
|
|
"""
|
|
|
|
import hashlib
|
|
import json
|
|
import os
|
|
import signal
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
import time
|
|
from pathlib import Path
|
|
|
|
# Reuse helpers from warmup_deep_gemm (same directory)
|
|
sys.path.insert(0, os.path.dirname(__file__))
|
|
from warmup_deep_gemm import get_architecture_key, get_config_json
|
|
|
|
MARKER_DIR = os.path.join(os.path.expanduser("~"), ".cache", "sglang", "warmup_markers")
|
|
HEALTH_POLL_INTERVAL = 10 # seconds between health checks
|
|
SERVER_STARTUP_TIMEOUT = 900 # 15 min max to wait for server ready
|
|
DEFAULT_PORT = 39876
|
|
|
|
|
|
def get_version_key():
|
|
"""Hash of Python + Triton + PyTorch versions to invalidate markers on upgrades."""
|
|
parts = [sys.version]
|
|
try:
|
|
import triton
|
|
|
|
parts.append(f"triton={triton.__version__}")
|
|
except ImportError:
|
|
parts.append("triton=none")
|
|
try:
|
|
import torch
|
|
|
|
parts.append(f"torch={torch.__version__}")
|
|
except ImportError:
|
|
parts.append("torch=none")
|
|
return hashlib.sha256("|".join(parts).encode()).hexdigest()[:12]
|
|
|
|
|
|
def get_marker_path(model, tp):
|
|
"""Get the marker file path for a model:tp pair."""
|
|
version_key = get_version_key()
|
|
safe_model = model.replace("/", "--")
|
|
return os.path.join(
|
|
MARKER_DIR, f"server_warmup_{safe_model}_tp{tp}_{version_key}.done"
|
|
)
|
|
|
|
|
|
def check_marker(model, tp):
|
|
"""Check if warmup marker exists (node already warm)."""
|
|
marker = get_marker_path(model, tp)
|
|
return os.path.exists(marker)
|
|
|
|
|
|
def write_marker(model, tp):
|
|
"""Write warmup marker after successful warmup."""
|
|
marker = get_marker_path(model, tp)
|
|
os.makedirs(os.path.dirname(marker), exist_ok=True)
|
|
Path(marker).write_text(
|
|
json.dumps(
|
|
{
|
|
"model": model,
|
|
"tp": tp,
|
|
"version_key": get_version_key(),
|
|
"timestamp": time.time(),
|
|
}
|
|
)
|
|
)
|
|
print(f" Wrote marker: {marker}")
|
|
|
|
|
|
def kill_server(proc):
|
|
"""Kill server process tree."""
|
|
if proc.poll() is not None:
|
|
return
|
|
try:
|
|
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
|
|
except (ProcessLookupError, OSError):
|
|
pass
|
|
try:
|
|
proc.wait(timeout=15)
|
|
except subprocess.TimeoutExpired:
|
|
try:
|
|
os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
|
|
except (ProcessLookupError, OSError):
|
|
pass
|
|
try:
|
|
proc.wait(timeout=5)
|
|
except subprocess.TimeoutExpired:
|
|
pass
|
|
|
|
|
|
def wait_for_server(base_url, proc, timeout):
|
|
"""Poll /health_generate until server is ready or timeout."""
|
|
import requests
|
|
|
|
start = time.time()
|
|
while time.time() - start < timeout:
|
|
ret = proc.poll()
|
|
if ret is not None:
|
|
return False, f"Server exited with code {ret}"
|
|
try:
|
|
resp = requests.get(f"{base_url}/health_generate", timeout=5)
|
|
if resp.status_code == 200:
|
|
return True, None
|
|
except requests.RequestException:
|
|
pass
|
|
time.sleep(HEALTH_POLL_INTERVAL)
|
|
return False, "Timed out waiting for server"
|
|
|
|
|
|
def send_generate_request(base_url):
|
|
"""Send one /generate request to exercise the full inference path."""
|
|
import requests
|
|
|
|
payload = {
|
|
"input_ids": [0, 1, 2, 3],
|
|
"sampling_params": {
|
|
"max_new_tokens": 8,
|
|
"temperature": 0,
|
|
},
|
|
}
|
|
try:
|
|
resp = requests.post(f"{base_url}/generate", json=payload, timeout=120)
|
|
if resp.status_code == 200:
|
|
print(" Generate request succeeded")
|
|
else:
|
|
print(f" Warning: generate request returned {resp.status_code}")
|
|
except requests.RequestException as e:
|
|
print(f" Warning: generate request failed: {e}")
|
|
|
|
|
|
def warmup_one_model(model, tp, port):
|
|
"""Launch server, wait for ready, send one request, then kill."""
|
|
base_url = f"http://127.0.0.1:{port}"
|
|
|
|
cmd = [
|
|
sys.executable,
|
|
"-m",
|
|
"sglang.launch_server",
|
|
"--model-path",
|
|
model,
|
|
"--tp",
|
|
str(tp),
|
|
"--host",
|
|
"127.0.0.1",
|
|
"--port",
|
|
str(port),
|
|
"--trust-remote-code",
|
|
"--model-loader-extra-config",
|
|
'{"enable_multithread_load": true, "num_threads": 64}',
|
|
]
|
|
|
|
# Use a temp file for server output to avoid pipe buffer deadlock
|
|
# (server logs can exceed the 64KB pipe buffer during CUDA graph capture)
|
|
log_file = tempfile.NamedTemporaryFile(
|
|
mode="w", prefix="warmup_server_", suffix=".log", delete=False
|
|
)
|
|
log_path = log_file.name
|
|
|
|
print(f" Launching server: {' '.join(cmd)}")
|
|
print(f" Server log: {log_path}")
|
|
proc = subprocess.Popen(
|
|
cmd,
|
|
stdout=log_file,
|
|
stderr=subprocess.STDOUT,
|
|
preexec_fn=os.setsid,
|
|
)
|
|
|
|
try:
|
|
# Wait for server to be ready (includes CUDA graph capture)
|
|
print(
|
|
f" Waiting for server (timeout={SERVER_STARTUP_TIMEOUT}s, "
|
|
f"polling every {HEALTH_POLL_INTERVAL}s)..."
|
|
)
|
|
ok, err = wait_for_server(base_url, proc, SERVER_STARTUP_TIMEOUT)
|
|
if not ok:
|
|
print(f" Warning: server not ready: {err}")
|
|
# Dump last lines of server log for debugging
|
|
try:
|
|
log_file.flush()
|
|
with open(log_path) as f:
|
|
lines = f.readlines()
|
|
for line in lines[-20:]:
|
|
print(f" | {line.rstrip()}")
|
|
except Exception:
|
|
pass
|
|
return False
|
|
|
|
print(" Server ready, sending generate request...")
|
|
send_generate_request(base_url)
|
|
return True
|
|
|
|
finally:
|
|
print(" Killing server...")
|
|
kill_server(proc)
|
|
log_file.close()
|
|
try:
|
|
os.unlink(log_path)
|
|
except OSError:
|
|
pass
|
|
|
|
|
|
def main():
|
|
if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help"):
|
|
print("Usage: warmup_server.py model1:tp1 [model2:tp2 ...]")
|
|
print(
|
|
"\nLaunches full servers with CUDA graphs enabled to pre-warm"
|
|
" Triton autotuning."
|
|
)
|
|
print("Skips instantly on warm nodes (marker file exists).")
|
|
sys.exit(0)
|
|
|
|
# Parse model:tp pairs
|
|
model_tp_pairs = []
|
|
for arg in sys.argv[1:]:
|
|
if ":" not in arg:
|
|
print(f"Error: expected model:tp format, got '{arg}'")
|
|
sys.exit(1)
|
|
model, tp_str = arg.rsplit(":", 1)
|
|
model_tp_pairs.append((model, int(tp_str)))
|
|
|
|
print(f"=== Server CUDA Graph Warmup ({len(model_tp_pairs)} model(s)) ===")
|
|
print(f" Marker dir: {MARKER_DIR}")
|
|
print(f" Version key: {get_version_key()}\n")
|
|
|
|
# Deduplicate by architecture and check markers
|
|
seen_keys = {}
|
|
to_warmup = []
|
|
|
|
for model, tp in model_tp_pairs:
|
|
# Check marker first (fast path)
|
|
if check_marker(model, tp):
|
|
print(f" SKIP {model} (tp={tp}): already warm (marker exists)")
|
|
continue
|
|
|
|
# Architecture dedup
|
|
config = get_config_json(model)
|
|
if config is not None:
|
|
key = get_architecture_key(config, tp)
|
|
if key in seen_keys:
|
|
print(
|
|
f" DEDUP {model} (tp={tp}): same architecture as {seen_keys[key]}"
|
|
)
|
|
continue
|
|
seen_keys[key] = model
|
|
|
|
to_warmup.append((model, tp))
|
|
print(f" QUEUE {model} (tp={tp}): needs warmup")
|
|
|
|
if not to_warmup:
|
|
print("\nAll models already warm. Done.")
|
|
return
|
|
|
|
print(f"\n{len(to_warmup)} model(s) to warm up.\n")
|
|
|
|
port = DEFAULT_PORT
|
|
for i, (model, tp) in enumerate(to_warmup, 1):
|
|
print(f"\n{'=' * 60}")
|
|
print(f"[{i}/{len(to_warmup)}] {model} (tp={tp})")
|
|
print(f"{'=' * 60}")
|
|
|
|
t0 = time.time()
|
|
success = warmup_one_model(model, tp, port)
|
|
elapsed = time.time() - t0
|
|
|
|
if success:
|
|
print(f" Completed in {elapsed:.0f}s")
|
|
write_marker(model, tp)
|
|
# Also write markers for dedup'd models that share this architecture
|
|
config = get_config_json(model)
|
|
if config is not None:
|
|
key = get_architecture_key(config, tp)
|
|
for other_model, other_tp in model_tp_pairs:
|
|
if (other_model, other_tp) == (model, tp):
|
|
continue
|
|
other_config = get_config_json(other_model)
|
|
if other_config is not None:
|
|
other_key = get_architecture_key(other_config, other_tp)
|
|
if other_key == key and not check_marker(other_model, other_tp):
|
|
write_marker(other_model, other_tp)
|
|
print(
|
|
f" Also marked {other_model} (tp={other_tp}) as warm (same arch)"
|
|
)
|
|
else:
|
|
print(
|
|
f" Warning: warmup failed after {elapsed:.0f}s (non-fatal, tests will still work)"
|
|
)
|
|
|
|
# Use a different port for the next model to avoid bind conflicts
|
|
port += 100
|
|
|
|
print("\nServer CUDA graph warmup complete.")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|