[Score API] Add return_pooled_hidden_states to Scoring API for SequenceClassification / RewardModel (#22427)

This commit is contained in:
Sundara Raman Ramachandran
2026-04-15 14:58:56 -07:00
committed by GitHub
parent 4e480d5785
commit 4927975427
22 changed files with 809 additions and 63 deletions

View File

@@ -702,6 +702,7 @@ class TboForwardBatchPreparer:
"mrope_positions", # only used by qwen2-vl, thus not care
"split_index", # for split prefill
"orig_seq_lens", # only used by qwen-1m, thus not care
"return_pooled_hidden_states",
]:
output_dict[key] = getattr(batch, key)
if not batch.forward_mode.is_target_verify():

View File

@@ -1436,6 +1436,17 @@ def is_piecewise_cuda_graph_disabled_model(model_architectures: List[str]):
)
# SequenceClassification models that use CrossEncodingPooler
_cross_encoding_pooler_archs = [
    "BertForSequenceClassification",
    "XLMRobertaForSequenceClassification",
]


def is_cross_encoding_pooler_model(model_architectures: List[str]) -> bool:
    """Report whether any listed architecture uses the cross-encoding pooler.

    Args:
        model_architectures: Architecture names to look up in
            ``_cross_encoding_pooler_archs``.

    Returns:
        True as soon as one name is found in the list; False otherwise,
        including when ``model_architectures`` is empty.
    """
    for candidate in model_architectures:
        if candidate in _cross_encoding_pooler_archs:
            return True
    return False
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
if scale <= 1:
return 1.0

View File

@@ -33,12 +33,10 @@ class EngineScoreMixin:
label_token_ids: Optional[List[int]] = None,
apply_softmax: bool = False,
item_first: bool = False,
# Placeholder token id in query/items that indicates override locations.
embed_override_token_id: Optional[int] = None,
# Query embedding overrides.
query_embed_overrides: Optional[List[torch.Tensor]] = None,
# Item embedding overrides: per-item list of tensors.
item_embed_overrides: Optional[List[Optional[List[torch.Tensor]]]] = None,
return_pooled_hidden_states: bool = False,
) -> ScoreResult:
"""
Score items against a query using the loaded model.
@@ -63,9 +61,13 @@ class EngineScoreMixin:
embed_override_token_id: Placeholder token ID used to locate override positions.
query_embed_overrides: Embedding vectors replacing placeholder tokens in query.
item_embed_overrides: Per-item embedding vectors replacing placeholder tokens in items.
return_pooled_hidden_states: Whether to include raw pooled transformer
hidden states (before the task head) in the result. Only supported
for non-generation models (SequenceClassification, RewardModel).
Returns:
ScoreResult with scores (one list per item) and prompt token count.
ScoreResult with scores (one list per item), prompt token count, and
optional pooled_hidden_states tensors.
"""
return self.loop.run_until_complete(
self.tokenizer_manager.score_request(
@@ -78,6 +80,7 @@ class EngineScoreMixin:
query_embed_overrides=query_embed_overrides,
item_embed_overrides=item_embed_overrides,
request=None,
return_pooled_hidden_states=return_pooled_hidden_states,
)
)
@@ -91,6 +94,7 @@ class EngineScoreMixin:
embed_override_token_id: Optional[int] = None,
query_embed_overrides: Optional[List[torch.Tensor]] = None,
item_embed_overrides: Optional[List[Optional[List[torch.Tensor]]]] = None,
return_pooled_hidden_states: bool = False,
) -> ScoreResult:
"""Asynchronous version of score(). See score() for full documentation."""
return await self.tokenizer_manager.score_request(
@@ -103,4 +107,5 @@ class EngineScoreMixin:
query_embed_overrides=query_embed_overrides,
item_embed_overrides=item_embed_overrides,
request=None,
return_pooled_hidden_states=return_pooled_hidden_states,
)

View File

@@ -1016,6 +1016,7 @@ class ScoringRequest(BaseModel):
)
apply_softmax: bool = False
item_first: bool = False
return_pooled_hidden_states: bool = False
model: str = DEFAULT_MODEL_NAME
@@ -1023,6 +1024,7 @@ class ScoringResponse(BaseModel):
scores: List[
List[float]
] # List of lists of probabilities, each in the order of label_token_ids
pooled_hidden_states: Optional[List[Optional[List[float]]]] = None
model: str
usage: Optional[UsageInfo] = None
object: str = "scoring"

View File

@@ -3,6 +3,7 @@ from typing import Union
import torch
from fastapi import Request
from fastapi.responses import ORJSONResponse
from sglang.srt.entrypoints.openai.protocol import (
ErrorResponse,
@@ -76,17 +77,26 @@ class OpenAIServingScore(OpenAIServingBase):
query_embed_overrides=query_embed_overrides,
item_embed_overrides=item_embed_overrides,
request=raw_request,
return_pooled_hidden_states=request.return_pooled_hidden_states,
)
phs_as_lists = None
if result.pooled_hidden_states is not None:
phs_as_lists = [
t.tolist() if t is not None else None
for t in result.pooled_hidden_states
]
response = ScoringResponse(
scores=result.scores,
pooled_hidden_states=phs_as_lists,
model=request.model,
usage=UsageInfo(
prompt_tokens=result.prompt_tokens,
total_tokens=result.prompt_tokens,
),
)
return response
return ORJSONResponse(content=response.model_dump())
except ValueError as e:
return self.create_error_response(str(e))

View File

@@ -1,18 +1,22 @@
# adapted from
# https://github.com/vllm-project/vllm/blob/82a1b1a82b1fbb454c82a9ef95730b929c9b270c/vllm/model_executor/layers/pooler.py
from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum
from typing import Optional
from typing import TYPE_CHECKING, List, Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from sglang.srt.layers.activation import get_cross_encoder_activation_function
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class PoolingType(IntEnum):
LAST = 0
@@ -21,9 +25,45 @@ class PoolingType(IntEnum):
@dataclass
class EmbeddingPoolerOutput:
    """Result container returned by a pooler or by ``score_and_pool``.

    Attributes:
        embeddings: The pooled embeddings or classification logits. A list of
            tensors is used instead of one tensor when the per-request shapes
            differ, either because matryoshka dim truncation is applied per
            request or because MIS produces a variable number of scores per
            request.
        pooled_hidden_states: The transformer hidden states taken *before*
            the task-specific head. Populated only when
            ``forward_batch.return_pooled_hidden_states`` is True: a single
            tensor on the standard path, or a list of tensors on the MIS path
            (one per delimiter). Defaults to None.
    """

    # Either one tensor for the batch or one tensor per request when shapes differ.
    embeddings: torch.Tensor | list[torch.Tensor]
    pooled_hidden_states: Optional[torch.Tensor | list[torch.Tensor]] = None
def pool_hidden_states(
    pooling_type: PoolingType,
    hidden_states: torch.Tensor,
    forward_batch: ForwardBatch,
) -> torch.Tensor:
    """Select one hidden-state row per sequence according to ``pooling_type``.

    This is raw pooling only: no normalization and no dimension truncation.

    Args:
        pooling_type: ``PoolingType.LAST`` picks each sequence's final token,
            ``PoolingType.CLS`` picks each sequence's first token.
        hidden_states: Hidden states indexed along dim 0 by flat token
            position across the batch.
        forward_batch: Supplies ``extend_seq_lens``, the per-sequence token
            counts used to locate sequence boundaries.

    Returns:
        The selected rows, shape (batch_size, hidden_size).

    Raises:
        ValueError: If ``pooling_type`` is neither LAST nor CLS.
    """
    # Reject unknown pooling types before touching the batch.
    if pooling_type not in (PoolingType.LAST, PoolingType.CLS):
        raise ValueError(f"Unsupported pooling type: {pooling_type}")

    seq_lens = forward_batch.extend_seq_lens
    # Inclusive running total: entry i is the flat offset one past sequence i.
    seq_ends = torch.cumsum(seq_lens, dim=0)

    if pooling_type == PoolingType.LAST:
        return hidden_states[seq_ends - 1]

    # CLS: a sequence starts where the previous one ended, i.e. end minus length.
    seq_starts = seq_ends - seq_lens
    return hidden_states[seq_starts]
def score_and_pool(
@@ -33,16 +73,16 @@ def score_and_pool(
forward_batch: ForwardBatch,
input_ids: torch.Tensor,
) -> EmbeddingPoolerOutput:
"""Apply a classification/score head with multi-item scoring (MIS) support.
"""Apply a classification/score head with MIS and pooled-hidden-states support.
When ``multi_item_scoring_delimiter`` is configured and found in
``input_ids``, takes the MIS path: extract hidden states at the positions
just before each delimiter, apply the score head only to those positions,
then split results per-request using ``forward_batch.extend_seq_lens``.
MIS path (when ``multi_item_scoring_delimiter`` is set and found in ``input_ids``):
extract hidden states at positions just before each delimiter, apply the score head,
then split per-request.
Otherwise, takes the normal single-item path: apply the score head to all
hidden states, then pool (matching the original classification model
forward logic).
Standard path: apply the score head to all hidden states, then pool.
When ``forward_batch.return_pooled_hidden_states`` is True, the raw pooled
hidden states (before the score head) are included in the output.
"""
delimiter_token = get_global_server_args().multi_item_scoring_delimiter
if delimiter_token is not None and forward_batch.is_prefill_only:
@@ -52,26 +92,40 @@ def score_and_pool(
if delim_positions.numel() > 0:
# Score only the tokens that precede a delimiter
scores = score_head(hidden_states[delim_positions - 1])
pre_delim_hidden = hidden_states[delim_positions - 1]
scores = score_head(pre_delim_hidden)
# Split per-request so the scheduler gets one tensor per request.
# Use CPU sequence lengths to avoid per-iteration GPU<->CPU sync
# from `.item()` calls on device tensors.
seq_lens = forward_batch.extend_seq_lens_cpu
start = 0
per_request = []
per_request_scores: List[torch.Tensor] = []
per_request_phs: Optional[List[torch.Tensor]] = (
[] if forward_batch.return_pooled_hidden_states else None
)
for seq_len in seq_lens:
end = start + seq_len
mask = (delim_positions >= start) & (delim_positions < end)
per_request.append(scores[mask])
per_request_scores.append(scores[mask])
if per_request_phs is not None:
per_request_phs.append(pre_delim_hidden[mask])
start = end
return EmbeddingPoolerOutput(embeddings=per_request)
return EmbeddingPoolerOutput(
embeddings=per_request_scores,
pooled_hidden_states=per_request_phs,
)
# Standard classification path: score all tokens, then pool.
logits = score_head(hidden_states)
pooled_logits = pooler(logits, forward_batch).embeddings
return EmbeddingPoolerOutput(embeddings=pooled_logits)
# Standard classification path: pool hidden states, then score.
pooled_hs = pool_hidden_states(pooler.pooling_type, hidden_states, forward_batch)
scores = score_head(pooled_hs)
return EmbeddingPoolerOutput(
embeddings=scores,
pooled_hidden_states=(
pooled_hs if forward_batch.return_pooled_hidden_states else None
),
)
class Pooler(nn.Module):
@@ -93,17 +147,9 @@ class Pooler(nn.Module):
def forward(
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> EmbeddingPoolerOutput:
if self.pooling_type == PoolingType.LAST:
last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
pooled_data = hidden_states[last_token_indices]
elif self.pooling_type == PoolingType.CLS:
prompt_lens = forward_batch.extend_seq_lens
first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
pooled_data = hidden_states[first_token_flat_indices]
else:
raise ValueError(f"Invalid pooling type: {self.pooling_type}")
pooled_data = pool_hidden_states(
self.pooling_type, hidden_states, forward_batch
)
if forward_batch.dimensions is not None:
all_same_dimensions = len(set(forward_batch.dimensions)) == 1

View File

@@ -852,6 +852,9 @@ class EmbeddingReqInput(BaseReq):
# The uid of LoRA adaptors, should be initialized by tokenizer manager
lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
# Whether to return pooled hidden states (pre-head transformer output)
return_pooled_hidden_states: bool = False
def normalize_batch_and_arguments(self):
# at least one of text, input_ids, or image should be provided
if self.text is None and self.input_ids is None and self.image_data is None:
@@ -953,6 +956,7 @@ class EmbeddingReqInput(BaseReq):
lora_id=self.lora_id[i] if self.lora_id is not None else None,
is_cross_encoder_request=True,
http_worker_ipc=self.http_worker_ipc,
return_pooled_hidden_states=self.return_pooled_hidden_states,
)
else:
sub = EmbeddingReqInput(
@@ -976,6 +980,7 @@ class EmbeddingReqInput(BaseReq):
dimensions=self.dimensions,
http_worker_ipc=self.http_worker_ipc,
received_time=self.received_time,
return_pooled_hidden_states=self.return_pooled_hidden_states,
)
cache[i] = sub
return sub
@@ -1007,6 +1012,9 @@ class TokenizedEmbeddingReqInput(BaseReq):
# For observability
time_stats: Optional[Union[APIServerReqTimeStats, DPControllerReqTimeStats]] = None
# Whether to return pooled hidden states (pre-head transformer output)
return_pooled_hidden_states: bool = False
@dataclass
class BatchTokenizedEmbeddingReqInput(BaseBatchReq):
@@ -1175,6 +1183,12 @@ class BatchEmbeddingOutput(BaseBatchReq):
# For observability
time_stats: Optional[List[SchedulerReqTimeStats]] = None
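# Optional pooled hidden states (pre-head transformer output).
# Sent as one stacked tensor when every request has an entry of the same shape
# (fewer objects to pickle); otherwise sent as a per-request list that may contain None.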
pooled_hidden_states: Optional[
Union[List[Optional[torch.Tensor]], torch.Tensor]
] = None
@dataclass
class ClearHiCacheReqInput(BaseReq):

View File

@@ -596,6 +596,7 @@ class Req(ReqDllmMixin):
time_stats: Optional[
Union[APIServerReqTimeStats, DPControllerReqTimeStats]
] = None,
return_pooled_hidden_states: bool = False,
):
# Input and output info
self.rid = rid
@@ -872,6 +873,10 @@ class Req(ReqDllmMixin):
# For Matryoshka embeddings
self.dimensions = dimensions
# Whether to return pooled hidden states (pre-head transformer output)
self.return_pooled_hidden_states = return_pooled_hidden_states
self.pooled_hidden_state = None
# For diffusion LLM
self.init_diffusion_llm(dllm_config)
@@ -1403,6 +1408,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# For matryoshka embeddings
dimensions: Optional[list[int]] = None
# Whether to return pooled hidden states (pre-head transformer output)
return_pooled_hidden_states: bool = False
# For split prefill
split_index: int = 0
split_prefill_finished: bool = False
@@ -1594,6 +1602,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
for r in reqs
]
# OR across the batch so ForwardBatch matches a single fused forward; requests
# that did not ask for PHS still skip attaching it in the output processor.
self.return_pooled_hidden_states = any(
r.return_pooled_hidden_states for r in reqs
)
token_type_ids = [
r.token_type_ids for r in reqs if r.token_type_ids is not None
]
@@ -2439,6 +2453,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
is_prefill_only=self.is_prefill_only,
dimensions=self.dimensions,
return_pooled_hidden_states=self.return_pooled_hidden_states,
dllm_block_offsets=[req.dllm_block_offset for req in self.reqs],
dllm_config=self.dllm_config,
reqs=self.reqs,
@@ -2632,6 +2647,9 @@ class ModelWorkerBatch:
# For matryoshka embeddings
dimensions: Optional[list[int]] = None
# Whether to return pooled hidden states (pre-head transformer output)
return_pooled_hidden_states: bool = False
# Whether this batch is prefill-only (no token generation needed)
is_prefill_only: bool = False

View File

@@ -254,12 +254,22 @@ _is_npu = is_npu()
@dataclass
class EmbeddingBatchResult:
"""Result from an embedding/classification forward pass.
Attributes:
embeddings: Model output — pooled embeddings or classification logits.
pooled_hidden_states: Raw hidden states before the task head. Present
only when the batch contained ``return_pooled_hidden_states=True``
requests. Tensor (uniform shapes) or list of tensors (MIS).
copy_done: CUDA event recorded after the async CPU copy completes.
"""
embeddings: torch.Tensor
pooled_hidden_states: Optional[torch.Tensor] = None
copy_done: Optional[torch.cuda.Event] = None
def copy_to_cpu(self):
"""Copy embeddings tensor to CPU in overlap scheduling."""
"""Copy embeddings and pooled hidden states to CPU for overlap scheduling."""
if isinstance(self.embeddings, torch.Tensor):
self.copy_done = torch.get_device_module(self.embeddings.device).Event()
self.embeddings = self.embeddings.to("cpu", non_blocking=True)
@@ -273,6 +283,16 @@ class EmbeddingBatchResult:
emb.to("cpu", non_blocking=True) for emb in self.embeddings
]
if self.pooled_hidden_states is not None:
if isinstance(self.pooled_hidden_states, list):
self.pooled_hidden_states = [
t.to("cpu", non_blocking=True) for t in self.pooled_hidden_states
]
else:
self.pooled_hidden_states = self.pooled_hidden_states.to(
"cpu", non_blocking=True
)
self.copy_done.record()
@@ -2157,6 +2177,7 @@ class Scheduler(
lora_id=recv_req.lora_id,
http_worker_ipc=recv_req.http_worker_ipc,
time_stats=recv_req.time_stats,
return_pooled_hidden_states=recv_req.return_pooled_hidden_states,
)
req.tokenizer = self.tokenizer
@@ -2828,14 +2849,22 @@ class Scheduler(
self.record_batch_in_overlap(model_worker_batch)
with self.forward_stream_ctx, self.record_bubble_metrics(batch):
self.forward_stream.wait_stream(self.schedule_stream)
embeddings = self.tp_worker.forward_batch_embedding(
pooler_output = self.tp_worker.forward_batch_embedding(
model_worker_batch
)
ret = EmbeddingBatchResult(embeddings=embeddings)
ret = EmbeddingBatchResult(
embeddings=pooler_output.embeddings,
pooled_hidden_states=pooler_output.pooled_hidden_states,
)
ret.copy_to_cpu()
else:
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
ret = EmbeddingBatchResult(embeddings=embeddings)
pooler_output = self.tp_worker.forward_batch_embedding(
model_worker_batch
)
ret = EmbeddingBatchResult(
embeddings=pooler_output.embeddings,
pooled_hidden_states=pooler_output.pooled_hidden_states,
)
# Capture prefill end time for EXTEND mode
if batch.forward_mode == ForwardMode.EXTEND:

View File

@@ -286,6 +286,7 @@ class SchedulerOutputProcessorMixin:
is_sparse = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set()
embeddings = result.embeddings
phs = result.pooled_hidden_states
if is_sparse:
batch_ids, token_ids = embeddings.indices()
@@ -302,12 +303,20 @@ class SchedulerOutputProcessorMixin:
else:
embeddings = [tensor.tolist() for tensor in embeddings]
if phs is not None:
if isinstance(phs, list):
phs = [t.cpu().detach() for t in phs]
else:
phs = phs.cpu().detach()
# Check finish conditions
for i, req in enumerate(batch.reqs):
if req.is_retracted:
continue
req.embedding = embeddings[i]
if req.return_pooled_hidden_states and phs is not None:
req.pooled_hidden_state = phs[i]
if req.is_chunked <= 0:
req.time_stats.set_prefill_finished_time()
# Dummy output token for embedding models
@@ -1215,6 +1224,8 @@ class SchedulerOutputProcessorMixin:
cached_tokens_details = [] # Detailed breakdown by cache source
time_stats = []
retraction_counts = []
phs_list = []
has_phs = False
for req in reqs:
if req.finished():
rids.append(req.rid)
@@ -1228,6 +1239,28 @@ class SchedulerOutputProcessorMixin:
cached_tokens_details.append(self._get_cached_tokens_details(req))
time_stats.append(req.time_stats)
retraction_counts.append(req.retraction_count)
phs = req.pooled_hidden_state
phs_list.append(phs)
if phs is not None:
has_phs = True
# Optimize PHS for pickle: torch.stack reduces N __reduce_ex__
# calls to 1 across the ZMQ IPC boundary. We can only stack when
# *every* entry is non-None (homogeneous batch); mixed batches
# (some requests want PHS, others don't) keep the raw list so
# positional indexing on the receiver side stays correct.
stacked_phs = None
if has_phs:
all_have_phs = all(t is not None for t in phs_list)
if all_have_phs:
if all(t.shape == phs_list[0].shape for t in phs_list):
stacked_phs = torch.stack(phs_list)
else:
stacked_phs = phs_list
else:
stacked_phs = phs_list
self.send_to_detokenizer.send_output(
BatchEmbeddingOutput(
rids=rids,
@@ -1241,5 +1274,6 @@ class SchedulerOutputProcessorMixin:
placeholder_tokens_idx=None,
placeholder_tokens_val=None,
retraction_counts=retraction_counts,
pooled_hidden_states=stacked_phs,
)
)

View File

@@ -1032,6 +1032,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerScoreMixin):
dimensions=obj.dimensions,
lora_id=obj.lora_id,
http_worker_ipc=obj.http_worker_ipc,
return_pooled_hidden_states=obj.return_pooled_hidden_states,
)
tokenized_obj.time_stats = self.rid_to_state[obj.rid].time_stats
@@ -1776,6 +1777,11 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerScoreMixin):
"embedding": recv_obj.embeddings[i],
"meta_info": meta_info,
}
if (
recv_obj.pooled_hidden_states is not None
and recv_obj.pooled_hidden_states[i] is not None
):
out_dict["pooled_hidden_state"] = recv_obj.pooled_hidden_states[i]
# Set first_token_time on the first output batch.
# This is the single write point for first_token_time.

View File

@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from sglang.srt.configs.model_config import is_cross_encoding_pooler_model
from sglang.srt.managers.embed_types import PositionalEmbeds
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
@@ -15,6 +16,12 @@ logger = logging.getLogger(__name__)
class ScoreResult:
scores: List[List[float]]
prompt_tokens: int = 0
# Per-item pooled hidden states (pre-head transformer output).
# CPU tensors when return_pooled_hidden_states=True; kept as tensors so
# in-process consumers (gRPC, engine API) avoid a .tolist() round-trip.
# The HTTP path converts to lists in serving_score.py before JSON serialization.
# Same layout as scores: one tensor per item (not a single packed 2D tensor).
pooled_hidden_states: Optional[List[Optional[torch.Tensor]]] = None
class TokenizerManagerScoreMixin:
@@ -151,6 +158,7 @@ class TokenizerManagerScoreMixin:
label_token_ids: Optional[List[int]],
apply_softmax: bool,
batch_request=None,
return_pooled_hidden_states: bool = False,
) -> ScoreResult:
"""
Process results from multi-item scoring request.
@@ -166,11 +174,13 @@ class TokenizerManagerScoreMixin:
label_token_ids: Token IDs to extract scores for
apply_softmax: Whether to apply softmax normalization
batch_request: The original batch request containing input sequence
return_pooled_hidden_states: Whether to extract pooled hidden states
from the result and include them in the ScoreResult.
Returns:
ScoreResult with:
scores: List of score lists, one for each prompt, each in the order of label_token_ids.
prompt_tokens: The number of prompt tokens processed.
ScoreResult with per-item scores, prompt token count, and optional
pooled_hidden_states (when return_pooled_hidden_states=True and the
model populated the field).
"""
single_result = results[0] if isinstance(results, list) else results
meta_info = single_result.get("meta_info", {})
@@ -225,10 +235,24 @@ class TokenizerManagerScoreMixin:
# Skip the first delimiter (query-item boundary)
scores = per_delimiter_scores[1:]
return ScoreResult(scores=scores, prompt_tokens=prompt_tokens)
phs_list = None
if return_pooled_hidden_states:
raw_phs = single_result.get("pooled_hidden_state")
if raw_phs is not None and len(raw_phs) == expected_count:
phs_list = raw_phs[1:]
return ScoreResult(
scores=scores,
prompt_tokens=prompt_tokens,
pooled_hidden_states=phs_list,
)
def _process_single_item_scoring_results(
self, results: Any, label_token_ids: Optional[List[int]], apply_softmax: bool
self,
results: Any,
label_token_ids: Optional[List[int]],
apply_softmax: bool,
return_pooled_hidden_states: bool = False,
) -> ScoreResult:
"""
Process results from single-item scoring request.
@@ -241,13 +265,14 @@ class TokenizerManagerScoreMixin:
results: Results from generate_request
label_token_ids: Token IDs to extract scores for (generation models only)
apply_softmax: Whether to apply softmax normalization
return_pooled_hidden_states: Whether to extract pooled hidden states
Returns:
ScoreResult with:
scores: List of score lists, one for each prompt, each in the order of label_token_ids.
prompt_tokens: The number of prompt tokens processed.
ScoreResult with per-item scores, prompt token count, and optional pooled_hidden_states.
"""
scores = []
phs_list = []
has_phs = False
prompt_tokens = 0
is_generation = getattr(self, "is_generation", True)
@@ -293,7 +318,17 @@ class TokenizerManagerScoreMixin:
# EmbeddingPoolerOutput API.
scores.append(embedding)
return ScoreResult(scores=scores, prompt_tokens=prompt_tokens)
if return_pooled_hidden_states:
phs = result.get("pooled_hidden_state")
phs_list.append(phs)
if phs is not None:
has_phs = True
return ScoreResult(
scores=scores,
prompt_tokens=prompt_tokens,
pooled_hidden_states=phs_list if has_phs else None,
)
# ------------------------------------------------------------------
# Embed override position resolution
@@ -481,6 +516,7 @@ class TokenizerManagerScoreMixin:
query_embed_overrides: Optional[List[torch.Tensor]] = None,
item_embed_overrides: Optional[List[Optional[List[torch.Tensor]]]] = None,
request: Optional[Any] = None,
return_pooled_hidden_states: bool = False,
) -> ScoreResult:
"""
Score the probability of specified token IDs appearing after the given (query + item) pair.
@@ -510,11 +546,18 @@ class TokenizerManagerScoreMixin:
query_embed_overrides: Embedding vectors replacing placeholder tokens in query.
item_embed_overrides: Per-item embedding vectors replacing placeholder tokens in items.
request: Optional FastAPI request object
return_pooled_hidden_states: Whether to include the raw pooled transformer
hidden states (before the task-specific head) in the result. Only
supported for non-generation models (SequenceClassification,
RewardModel). Raises ValueError for CausalLM models.
Returns:
ScoreResult with:
scores: List of score lists, one for each prompt, each in the order of label_token_ids.
scores: List of score lists, one per item.
prompt_tokens: The number of prompt tokens processed.
pooled_hidden_states: Per-item CPU tensors when
return_pooled_hidden_states=True and the model supports it;
None otherwise.
"""
is_generation = getattr(self, "is_generation", True)
@@ -618,6 +661,23 @@ class TokenizerManagerScoreMixin:
"Invalid combination of query/items types for score_request."
)
if return_pooled_hidden_states:
if is_generation:
raise ValueError(
"return_pooled_hidden_states is not supported for CausalLM models. "
"It requires a model with a task-specific head "
"(e.g. SequenceClassification or RewardModel)."
)
model_config = getattr(self, "model_config", None)
if model_config is not None:
archs = getattr(model_config.hf_config, "architectures", []) or []
if is_cross_encoding_pooler_model(archs):
raise ValueError(
f"return_pooled_hidden_states is not supported for "
f"{archs[0]}. This model uses CrossEncodingPooler which "
f"does not expose pre-head hidden states."
)
# Create the appropriate request type
if is_generation:
batch_request = GenerateReqInput(
@@ -636,6 +696,7 @@ class TokenizerManagerScoreMixin:
text=text_prompts,
input_ids=input_ids,
positional_embed_overrides=positional_embed_overrides,
return_pooled_hidden_states=return_pooled_hidden_states,
)
results = await self.generate_request(batch_request, request).__anext__()
@@ -643,12 +704,17 @@ class TokenizerManagerScoreMixin:
if use_multi_item_scoring:
# Multi-item scoring: extract scores from input_token_ids_logprobs or embedding
return self._process_multi_item_scoring_results(
results, items, label_token_ids, apply_softmax, batch_request
results,
items,
label_token_ids,
apply_softmax,
batch_request,
return_pooled_hidden_states,
)
else:
# Single-item scoring: process each result separately
return self._process_single_item_scoring_results(
results, label_token_ids, apply_softmax
results, label_token_ids, apply_softmax, return_pooled_hidden_states
)
def _convert_logprobs_to_scores(

View File

@@ -210,9 +210,8 @@ class BaseTpWorker(ABC):
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch).logits_output
embeddings = logits_output.embeddings
return embeddings
output = self.model_runner.forward(forward_batch).logits_output
return output # Returns EmbeddingPoolerOutput
class TpModelWorker(BaseTpWorker):

View File

@@ -427,6 +427,9 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin):
# For hidden states before norm
return_hidden_states_before_norm: bool = False
# Whether to return pooled hidden states (pre-head transformer output)
return_pooled_hidden_states: bool = False
# For hisparse
hisparse_coordinator: Optional[HiSparseCoordinator] = None
@@ -483,6 +486,7 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin):
tbo_split_seq_index=batch.tbo_split_seq_index,
dimensions=batch.dimensions,
return_hidden_states_before_norm=batch.return_hidden_states_before_norm,
return_pooled_hidden_states=batch.return_pooled_hidden_states,
rids=[req.rid for req in batch.reqs],
)
device = model_runner.device

View File

@@ -206,6 +206,9 @@ class PiecewiseCudaGraphRunner:
self.is_multimodal = model_runner.is_multimodal
self.mamba_track_enabled = self.is_mamba_track_enabled()
# Classification/reward forwards branch on return_pooled_hidden_states; piecewise
# CUDA graph capture must use the same flag value as replay for those models.
self.capture_return_pooled_hidden_states = not model_runner.is_generation
# Graph inputs
with torch.device(self.device):
@@ -390,6 +393,7 @@ class PiecewiseCudaGraphRunner:
num_token_non_padded_cpu=num_tokens,
global_forward_mode=ForwardMode.EXTEND,
lora_ids=None,
return_pooled_hidden_states=self.capture_return_pooled_hidden_states,
)
# Attention backend
@@ -551,6 +555,7 @@ class PiecewiseCudaGraphRunner:
num_token_non_padded_cpu=num_tokens,
global_forward_mode=ForwardMode.EXTEND,
lora_ids=None,
return_pooled_hidden_states=self.capture_return_pooled_hidden_states,
)
self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
@@ -748,6 +753,10 @@ class PiecewiseCudaGraphRunner:
top_p_normalized_logprobs=forward_batch.top_p_normalized_logprobs,
top_p=forward_batch.top_p,
dimensions=forward_batch.dimensions,
return_pooled_hidden_states=(
self.capture_return_pooled_hidden_states
or forward_batch.return_pooled_hidden_states
),
)
if out_cache_loc_swa is not None:

View File

@@ -61,7 +61,12 @@ class Gemma2ForSequenceClassification(nn.Module):
last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
scores = self.score(last_token_hidden)
return EmbeddingPoolerOutput(scores)
return EmbeddingPoolerOutput(
embeddings=scores,
pooled_hidden_states=(
last_token_hidden if forward_batch.return_pooled_hidden_states else None
),
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
Gemma2ForCausalLM.load_weights(self, weights)

View File

@@ -55,7 +55,12 @@ class InternLM2ForRewardModel(nn.Module):
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
scores = self.v_head(last_token_hidden)
return EmbeddingPoolerOutput(scores)
return EmbeddingPoolerOutput(
embeddings=scores,
pooled_hidden_states=(
last_token_hidden if forward_batch.return_pooled_hidden_states else None
),
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
return InternLM2ForCausalLM.load_weights(self, weights)

View File

@@ -18,7 +18,12 @@ import torch
from torch import nn
from transformers import LlamaConfig
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.pooler import (
EmbeddingPoolerOutput,
Pooler,
PoolingType,
pool_hidden_states,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
@@ -61,7 +66,12 @@ class LlamaForSequenceClassification(nn.Module):
last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
scores = self.score(last_token_hidden)
return EmbeddingPoolerOutput(scores)
return EmbeddingPoolerOutput(
embeddings=scores,
pooled_hidden_states=(
last_token_hidden if forward_batch.return_pooled_hidden_states else None
),
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
return LlamaForCausalLM.load_weights(self, weights)
@@ -114,7 +124,17 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
-1, self.num_labels // 2
)
scores = (rews * pooled_weights).sum(dim=-1).view(-1, 1)
return EmbeddingPoolerOutput(scores)
pooled_hidden = None
if forward_batch.return_pooled_hidden_states:
pooled_hidden = pool_hidden_states(
self.pooler.pooling_type, hidden_states, forward_batch
)
return EmbeddingPoolerOutput(
embeddings=scores,
pooled_hidden_states=pooled_hidden,
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
return super().load_weights(weights)

View File

@@ -18,7 +18,12 @@ import torch
from torch import nn
from transformers import Qwen2Config
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.pooler import (
EmbeddingPoolerOutput,
Pooler,
PoolingType,
pool_hidden_states,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen2 import Qwen2ForCausalLM, Qwen2Model
@@ -63,7 +68,16 @@ class Qwen2ForRewardModel(nn.Module):
logits = self.score(hidden_states)
pooled_logits = self.pooler(logits, forward_batch).embeddings
return EmbeddingPoolerOutput(pooled_logits)
pooled_hidden = None
if forward_batch.return_pooled_hidden_states:
pooled_hidden = pool_hidden_states(
self.pooler.pooling_type, hidden_states, forward_batch
)
return EmbeddingPoolerOutput(
embeddings=pooled_logits,
pooled_hidden_states=pooled_hidden,
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Filter out lm_head weights of Qwen2ForCausalLM

View File

@@ -800,6 +800,9 @@ class ServerArgs:
# Handle piecewise CUDA graph.
self._handle_piecewise_cuda_graph()
# Handle multi-item scoring constraints.
self._handle_multi_item_scoring()
# Get GPU memory capacity, which is a common dependency for several configuration steps.
gpu_mem = get_device_memory_capacity(self.device)
@@ -1205,6 +1208,21 @@ class ServerArgs:
if self.debug_cuda_graph:
self.disable_piecewise_cuda_graph = True
def _handle_multi_item_scoring(self):
"""Disable CUDA graphs when multi-item scoring delimiter is set.
The padded static input_ids buffer used by CUDA graph replay causes
spurious delimiter matches in score_and_pool's MIS path.
"""
if self.multi_item_scoring_delimiter is None:
return
if not self.disable_cuda_graph:
logger.warning(
"CUDA graph is disabled because --multi-item-scoring-delimiter is set."
)
self.disable_cuda_graph = True
self.disable_piecewise_cuda_graph = True
def _handle_gpu_memory_settings(self, gpu_mem):
"""
Configure GPU memory-dependent settings including