mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-07-01 04:08:10 +00:00
[Score API] Add return_pooled_hidden_states to Scoring API for SequenceClassification / RewardModel (#22427)
This commit is contained in:
committed by
GitHub
parent
4e480d5785
commit
4927975427
@@ -702,6 +702,7 @@ class TboForwardBatchPreparer:
|
||||
"mrope_positions", # only used by qwen2-vl, thus not care
|
||||
"split_index", # for split prefill
|
||||
"orig_seq_lens", # only used by qwen-1m, thus not care
|
||||
"return_pooled_hidden_states",
|
||||
]:
|
||||
output_dict[key] = getattr(batch, key)
|
||||
if not batch.forward_mode.is_target_verify():
|
||||
|
||||
@@ -1436,6 +1436,17 @@ def is_piecewise_cuda_graph_disabled_model(model_architectures: List[str]):
|
||||
)
|
||||
|
||||
|
||||
# SequenceClassification models that use CrossEncodingPooler
|
||||
_cross_encoding_pooler_archs = [
|
||||
"BertForSequenceClassification",
|
||||
"XLMRobertaForSequenceClassification",
|
||||
]
|
||||
|
||||
|
||||
def is_cross_encoding_pooler_model(model_architectures: List[str]) -> bool:
|
||||
return any(arch in _cross_encoding_pooler_archs for arch in model_architectures)
|
||||
|
||||
|
||||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
|
||||
@@ -33,12 +33,10 @@ class EngineScoreMixin:
|
||||
label_token_ids: Optional[List[int]] = None,
|
||||
apply_softmax: bool = False,
|
||||
item_first: bool = False,
|
||||
# Placeholder token id in query/items that indicates override locations.
|
||||
embed_override_token_id: Optional[int] = None,
|
||||
# Query embedding overrides.
|
||||
query_embed_overrides: Optional[List[torch.Tensor]] = None,
|
||||
# Item embedding overrides: per-item list of tensors.
|
||||
item_embed_overrides: Optional[List[Optional[List[torch.Tensor]]]] = None,
|
||||
return_pooled_hidden_states: bool = False,
|
||||
) -> ScoreResult:
|
||||
"""
|
||||
Score items against a query using the loaded model.
|
||||
@@ -63,9 +61,13 @@ class EngineScoreMixin:
|
||||
embed_override_token_id: Placeholder token ID used to locate override positions.
|
||||
query_embed_overrides: Embedding vectors replacing placeholder tokens in query.
|
||||
item_embed_overrides: Per-item embedding vectors replacing placeholder tokens in items.
|
||||
return_pooled_hidden_states: Whether to include raw pooled transformer
|
||||
hidden states (before the task head) in the result. Only supported
|
||||
for non-generation models (SequenceClassification, RewardModel).
|
||||
|
||||
Returns:
|
||||
ScoreResult with scores (one list per item) and prompt token count.
|
||||
ScoreResult with scores (one list per item), prompt token count, and
|
||||
optional pooled_hidden_states tensors.
|
||||
"""
|
||||
return self.loop.run_until_complete(
|
||||
self.tokenizer_manager.score_request(
|
||||
@@ -78,6 +80,7 @@ class EngineScoreMixin:
|
||||
query_embed_overrides=query_embed_overrides,
|
||||
item_embed_overrides=item_embed_overrides,
|
||||
request=None,
|
||||
return_pooled_hidden_states=return_pooled_hidden_states,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -91,6 +94,7 @@ class EngineScoreMixin:
|
||||
embed_override_token_id: Optional[int] = None,
|
||||
query_embed_overrides: Optional[List[torch.Tensor]] = None,
|
||||
item_embed_overrides: Optional[List[Optional[List[torch.Tensor]]]] = None,
|
||||
return_pooled_hidden_states: bool = False,
|
||||
) -> ScoreResult:
|
||||
"""Asynchronous version of score(). See score() for full documentation."""
|
||||
return await self.tokenizer_manager.score_request(
|
||||
@@ -103,4 +107,5 @@ class EngineScoreMixin:
|
||||
query_embed_overrides=query_embed_overrides,
|
||||
item_embed_overrides=item_embed_overrides,
|
||||
request=None,
|
||||
return_pooled_hidden_states=return_pooled_hidden_states,
|
||||
)
|
||||
|
||||
@@ -1016,6 +1016,7 @@ class ScoringRequest(BaseModel):
|
||||
)
|
||||
apply_softmax: bool = False
|
||||
item_first: bool = False
|
||||
return_pooled_hidden_states: bool = False
|
||||
model: str = DEFAULT_MODEL_NAME
|
||||
|
||||
|
||||
@@ -1023,6 +1024,7 @@ class ScoringResponse(BaseModel):
|
||||
scores: List[
|
||||
List[float]
|
||||
] # List of lists of probabilities, each in the order of label_token_ids
|
||||
pooled_hidden_states: Optional[List[Optional[List[float]]]] = None
|
||||
model: str
|
||||
usage: Optional[UsageInfo] = None
|
||||
object: str = "scoring"
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Union
|
||||
|
||||
import torch
|
||||
from fastapi import Request
|
||||
from fastapi.responses import ORJSONResponse
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
@@ -76,17 +77,26 @@ class OpenAIServingScore(OpenAIServingBase):
|
||||
query_embed_overrides=query_embed_overrides,
|
||||
item_embed_overrides=item_embed_overrides,
|
||||
request=raw_request,
|
||||
return_pooled_hidden_states=request.return_pooled_hidden_states,
|
||||
)
|
||||
|
||||
phs_as_lists = None
|
||||
if result.pooled_hidden_states is not None:
|
||||
phs_as_lists = [
|
||||
t.tolist() if t is not None else None
|
||||
for t in result.pooled_hidden_states
|
||||
]
|
||||
|
||||
response = ScoringResponse(
|
||||
scores=result.scores,
|
||||
pooled_hidden_states=phs_as_lists,
|
||||
model=request.model,
|
||||
usage=UsageInfo(
|
||||
prompt_tokens=result.prompt_tokens,
|
||||
total_tokens=result.prompt_tokens,
|
||||
),
|
||||
)
|
||||
return response
|
||||
return ORJSONResponse(content=response.model_dump())
|
||||
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
# adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/82a1b1a82b1fbb454c82a9ef95730b929c9b270c/vllm/model_executor/layers/pooler.py
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.layers.activation import get_cross_encoder_activation_function
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.server_args import get_global_server_args
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
class PoolingType(IntEnum):
|
||||
LAST = 0
|
||||
@@ -21,9 +25,45 @@ class PoolingType(IntEnum):
|
||||
|
||||
@dataclass
|
||||
class EmbeddingPoolerOutput:
|
||||
"""Output of pooler or score_and_pool.
|
||||
|
||||
Attributes:
|
||||
embeddings: Pooled embeddings or classification logits. May be a list
|
||||
of tensors when per-request matryoshka dim truncation produces
|
||||
different shapes, or when MIS yields a variable number of scores
|
||||
per request.
|
||||
pooled_hidden_states: Raw transformer hidden states *before* the
|
||||
task-specific head, present only when
|
||||
``forward_batch.return_pooled_hidden_states`` is True. Tensor
|
||||
(standard path) or list of tensors (MIS path, one per delimiter).
|
||||
"""
|
||||
|
||||
# Pooler can return list[tensor] instead of tensor if the dimension of each tensor in the batch is different
|
||||
# due to different per-request matryoshka dim truncation
|
||||
embeddings: torch.Tensor | list[torch.Tensor]
|
||||
pooled_hidden_states: Optional[torch.Tensor | list[torch.Tensor]] = None
|
||||
|
||||
|
||||
def pool_hidden_states(
|
||||
pooling_type: PoolingType,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
"""Pool hidden_states by PoolingType (LAST/CLS).
|
||||
|
||||
Raw pooling only — no normalize, no dim truncation.
|
||||
Returns shape (batch_size, hidden_size).
|
||||
"""
|
||||
if pooling_type == PoolingType.LAST:
|
||||
last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
|
||||
return hidden_states[last_token_indices]
|
||||
elif pooling_type == PoolingType.CLS:
|
||||
prompt_lens = forward_batch.extend_seq_lens
|
||||
first_token_flat_indices = torch.zeros_like(prompt_lens)
|
||||
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
|
||||
return hidden_states[first_token_flat_indices]
|
||||
else:
|
||||
raise ValueError(f"Unsupported pooling type: {pooling_type}")
|
||||
|
||||
|
||||
def score_and_pool(
|
||||
@@ -33,16 +73,16 @@ def score_and_pool(
|
||||
forward_batch: ForwardBatch,
|
||||
input_ids: torch.Tensor,
|
||||
) -> EmbeddingPoolerOutput:
|
||||
"""Apply a classification/score head with multi-item scoring (MIS) support.
|
||||
"""Apply a classification/score head with MIS and pooled-hidden-states support.
|
||||
|
||||
When ``multi_item_scoring_delimiter`` is configured and found in
|
||||
``input_ids``, takes the MIS path: extract hidden states at the positions
|
||||
just before each delimiter, apply the score head only to those positions,
|
||||
then split results per-request using ``forward_batch.extend_seq_lens``.
|
||||
MIS path (when ``multi_item_scoring_delimiter`` is set and found in ``input_ids``):
|
||||
extract hidden states at positions just before each delimiter, apply the score head,
|
||||
then split per-request.
|
||||
|
||||
Otherwise, takes the normal single-item path: apply the score head to all
|
||||
hidden states, then pool (matching the original classification model
|
||||
forward logic).
|
||||
Standard path: apply the score head to all hidden states, then pool.
|
||||
|
||||
When ``forward_batch.return_pooled_hidden_states`` is True, the raw pooled
|
||||
hidden states (before the score head) are included in the output.
|
||||
"""
|
||||
delimiter_token = get_global_server_args().multi_item_scoring_delimiter
|
||||
if delimiter_token is not None and forward_batch.is_prefill_only:
|
||||
@@ -52,26 +92,40 @@ def score_and_pool(
|
||||
|
||||
if delim_positions.numel() > 0:
|
||||
# Score only the tokens that precede a delimiter
|
||||
scores = score_head(hidden_states[delim_positions - 1])
|
||||
pre_delim_hidden = hidden_states[delim_positions - 1]
|
||||
scores = score_head(pre_delim_hidden)
|
||||
|
||||
# Split per-request so the scheduler gets one tensor per request.
|
||||
# Use CPU sequence lengths to avoid per-iteration GPU<->CPU sync
|
||||
# from `.item()` calls on device tensors.
|
||||
seq_lens = forward_batch.extend_seq_lens_cpu
|
||||
start = 0
|
||||
per_request = []
|
||||
per_request_scores: List[torch.Tensor] = []
|
||||
per_request_phs: Optional[List[torch.Tensor]] = (
|
||||
[] if forward_batch.return_pooled_hidden_states else None
|
||||
)
|
||||
for seq_len in seq_lens:
|
||||
end = start + seq_len
|
||||
mask = (delim_positions >= start) & (delim_positions < end)
|
||||
per_request.append(scores[mask])
|
||||
per_request_scores.append(scores[mask])
|
||||
if per_request_phs is not None:
|
||||
per_request_phs.append(pre_delim_hidden[mask])
|
||||
start = end
|
||||
|
||||
return EmbeddingPoolerOutput(embeddings=per_request)
|
||||
return EmbeddingPoolerOutput(
|
||||
embeddings=per_request_scores,
|
||||
pooled_hidden_states=per_request_phs,
|
||||
)
|
||||
|
||||
# Standard classification path: score all tokens, then pool.
|
||||
logits = score_head(hidden_states)
|
||||
pooled_logits = pooler(logits, forward_batch).embeddings
|
||||
return EmbeddingPoolerOutput(embeddings=pooled_logits)
|
||||
# Standard classification path: pool hidden states, then score.
|
||||
pooled_hs = pool_hidden_states(pooler.pooling_type, hidden_states, forward_batch)
|
||||
scores = score_head(pooled_hs)
|
||||
return EmbeddingPoolerOutput(
|
||||
embeddings=scores,
|
||||
pooled_hidden_states=(
|
||||
pooled_hs if forward_batch.return_pooled_hidden_states else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Pooler(nn.Module):
|
||||
@@ -93,17 +147,9 @@ class Pooler(nn.Module):
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
|
||||
) -> EmbeddingPoolerOutput:
|
||||
|
||||
if self.pooling_type == PoolingType.LAST:
|
||||
last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
|
||||
pooled_data = hidden_states[last_token_indices]
|
||||
elif self.pooling_type == PoolingType.CLS:
|
||||
prompt_lens = forward_batch.extend_seq_lens
|
||||
first_token_flat_indices = torch.zeros_like(prompt_lens)
|
||||
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
|
||||
pooled_data = hidden_states[first_token_flat_indices]
|
||||
else:
|
||||
raise ValueError(f"Invalid pooling type: {self.pooling_type}")
|
||||
pooled_data = pool_hidden_states(
|
||||
self.pooling_type, hidden_states, forward_batch
|
||||
)
|
||||
|
||||
if forward_batch.dimensions is not None:
|
||||
all_same_dimensions = len(set(forward_batch.dimensions)) == 1
|
||||
|
||||
@@ -852,6 +852,9 @@ class EmbeddingReqInput(BaseReq):
|
||||
# The uid of LoRA adaptors, should be initialized by tokenizer manager
|
||||
lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
|
||||
# Whether to return pooled hidden states (pre-head transformer output)
|
||||
return_pooled_hidden_states: bool = False
|
||||
|
||||
def normalize_batch_and_arguments(self):
|
||||
# at least one of text, input_ids, or image should be provided
|
||||
if self.text is None and self.input_ids is None and self.image_data is None:
|
||||
@@ -953,6 +956,7 @@ class EmbeddingReqInput(BaseReq):
|
||||
lora_id=self.lora_id[i] if self.lora_id is not None else None,
|
||||
is_cross_encoder_request=True,
|
||||
http_worker_ipc=self.http_worker_ipc,
|
||||
return_pooled_hidden_states=self.return_pooled_hidden_states,
|
||||
)
|
||||
else:
|
||||
sub = EmbeddingReqInput(
|
||||
@@ -976,6 +980,7 @@ class EmbeddingReqInput(BaseReq):
|
||||
dimensions=self.dimensions,
|
||||
http_worker_ipc=self.http_worker_ipc,
|
||||
received_time=self.received_time,
|
||||
return_pooled_hidden_states=self.return_pooled_hidden_states,
|
||||
)
|
||||
cache[i] = sub
|
||||
return sub
|
||||
@@ -1007,6 +1012,9 @@ class TokenizedEmbeddingReqInput(BaseReq):
|
||||
# For observability
|
||||
time_stats: Optional[Union[APIServerReqTimeStats, DPControllerReqTimeStats]] = None
|
||||
|
||||
# Whether to return pooled hidden states (pre-head transformer output)
|
||||
return_pooled_hidden_states: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchTokenizedEmbeddingReqInput(BaseBatchReq):
|
||||
@@ -1175,6 +1183,12 @@ class BatchEmbeddingOutput(BaseBatchReq):
|
||||
# For observability
|
||||
time_stats: Optional[List[SchedulerReqTimeStats]] = None
|
||||
|
||||
# Optional pooled hidden states (pre-head transformer output).
|
||||
# Sent as a single stacked tensor to minimize pickle overhead.
|
||||
pooled_hidden_states: Optional[
|
||||
Union[List[Optional[torch.Tensor]], torch.Tensor]
|
||||
] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClearHiCacheReqInput(BaseReq):
|
||||
|
||||
@@ -596,6 +596,7 @@ class Req(ReqDllmMixin):
|
||||
time_stats: Optional[
|
||||
Union[APIServerReqTimeStats, DPControllerReqTimeStats]
|
||||
] = None,
|
||||
return_pooled_hidden_states: bool = False,
|
||||
):
|
||||
# Input and output info
|
||||
self.rid = rid
|
||||
@@ -872,6 +873,10 @@ class Req(ReqDllmMixin):
|
||||
# For Matryoshka embeddings
|
||||
self.dimensions = dimensions
|
||||
|
||||
# Whether to return pooled hidden states (pre-head transformer output)
|
||||
self.return_pooled_hidden_states = return_pooled_hidden_states
|
||||
self.pooled_hidden_state = None
|
||||
|
||||
# For diffusion LLM
|
||||
self.init_diffusion_llm(dllm_config)
|
||||
|
||||
@@ -1403,6 +1408,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
# For matryoshka embeddings
|
||||
dimensions: Optional[list[int]] = None
|
||||
|
||||
# Whether to return pooled hidden states (pre-head transformer output)
|
||||
return_pooled_hidden_states: bool = False
|
||||
|
||||
# For split prefill
|
||||
split_index: int = 0
|
||||
split_prefill_finished: bool = False
|
||||
@@ -1594,6 +1602,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
for r in reqs
|
||||
]
|
||||
|
||||
# OR across the batch so ForwardBatch matches a single fused forward; requests
|
||||
# that did not ask for PHS still skip attaching it in the output processor.
|
||||
self.return_pooled_hidden_states = any(
|
||||
r.return_pooled_hidden_states for r in reqs
|
||||
)
|
||||
|
||||
token_type_ids = [
|
||||
r.token_type_ids for r in reqs if r.token_type_ids is not None
|
||||
]
|
||||
@@ -2439,6 +2453,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
|
||||
is_prefill_only=self.is_prefill_only,
|
||||
dimensions=self.dimensions,
|
||||
return_pooled_hidden_states=self.return_pooled_hidden_states,
|
||||
dllm_block_offsets=[req.dllm_block_offset for req in self.reqs],
|
||||
dllm_config=self.dllm_config,
|
||||
reqs=self.reqs,
|
||||
@@ -2632,6 +2647,9 @@ class ModelWorkerBatch:
|
||||
# For matryoshka embeddings
|
||||
dimensions: Optional[list[int]] = None
|
||||
|
||||
# Whether to return pooled hidden states (pre-head transformer output)
|
||||
return_pooled_hidden_states: bool = False
|
||||
|
||||
# Whether this batch is prefill-only (no token generation needed)
|
||||
is_prefill_only: bool = False
|
||||
|
||||
|
||||
@@ -254,12 +254,22 @@ _is_npu = is_npu()
|
||||
|
||||
@dataclass
|
||||
class EmbeddingBatchResult:
|
||||
"""Result from an embedding/classification forward pass.
|
||||
|
||||
Attributes:
|
||||
embeddings: Model output — pooled embeddings or classification logits.
|
||||
pooled_hidden_states: Raw hidden states before the task head. Present
|
||||
only when the batch contained ``return_pooled_hidden_states=True``
|
||||
requests. Tensor (uniform shapes) or list of tensors (MIS).
|
||||
copy_done: CUDA event recorded after the async CPU copy completes.
|
||||
"""
|
||||
|
||||
embeddings: torch.Tensor
|
||||
pooled_hidden_states: Optional[torch.Tensor] = None
|
||||
copy_done: Optional[torch.cuda.Event] = None
|
||||
|
||||
def copy_to_cpu(self):
|
||||
"""Copy embeddings tensor to CPU in overlap scheduling."""
|
||||
|
||||
"""Copy embeddings and pooled hidden states to CPU for overlap scheduling."""
|
||||
if isinstance(self.embeddings, torch.Tensor):
|
||||
self.copy_done = torch.get_device_module(self.embeddings.device).Event()
|
||||
self.embeddings = self.embeddings.to("cpu", non_blocking=True)
|
||||
@@ -273,6 +283,16 @@ class EmbeddingBatchResult:
|
||||
emb.to("cpu", non_blocking=True) for emb in self.embeddings
|
||||
]
|
||||
|
||||
if self.pooled_hidden_states is not None:
|
||||
if isinstance(self.pooled_hidden_states, list):
|
||||
self.pooled_hidden_states = [
|
||||
t.to("cpu", non_blocking=True) for t in self.pooled_hidden_states
|
||||
]
|
||||
else:
|
||||
self.pooled_hidden_states = self.pooled_hidden_states.to(
|
||||
"cpu", non_blocking=True
|
||||
)
|
||||
|
||||
self.copy_done.record()
|
||||
|
||||
|
||||
@@ -2157,6 +2177,7 @@ class Scheduler(
|
||||
lora_id=recv_req.lora_id,
|
||||
http_worker_ipc=recv_req.http_worker_ipc,
|
||||
time_stats=recv_req.time_stats,
|
||||
return_pooled_hidden_states=recv_req.return_pooled_hidden_states,
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
|
||||
@@ -2828,14 +2849,22 @@ class Scheduler(
|
||||
self.record_batch_in_overlap(model_worker_batch)
|
||||
with self.forward_stream_ctx, self.record_bubble_metrics(batch):
|
||||
self.forward_stream.wait_stream(self.schedule_stream)
|
||||
embeddings = self.tp_worker.forward_batch_embedding(
|
||||
pooler_output = self.tp_worker.forward_batch_embedding(
|
||||
model_worker_batch
|
||||
)
|
||||
ret = EmbeddingBatchResult(embeddings=embeddings)
|
||||
ret = EmbeddingBatchResult(
|
||||
embeddings=pooler_output.embeddings,
|
||||
pooled_hidden_states=pooler_output.pooled_hidden_states,
|
||||
)
|
||||
ret.copy_to_cpu()
|
||||
else:
|
||||
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
|
||||
ret = EmbeddingBatchResult(embeddings=embeddings)
|
||||
pooler_output = self.tp_worker.forward_batch_embedding(
|
||||
model_worker_batch
|
||||
)
|
||||
ret = EmbeddingBatchResult(
|
||||
embeddings=pooler_output.embeddings,
|
||||
pooled_hidden_states=pooler_output.pooled_hidden_states,
|
||||
)
|
||||
|
||||
# Capture prefill end time for EXTEND mode
|
||||
if batch.forward_mode == ForwardMode.EXTEND:
|
||||
|
||||
@@ -286,6 +286,7 @@ class SchedulerOutputProcessorMixin:
|
||||
is_sparse = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set()
|
||||
|
||||
embeddings = result.embeddings
|
||||
phs = result.pooled_hidden_states
|
||||
|
||||
if is_sparse:
|
||||
batch_ids, token_ids = embeddings.indices()
|
||||
@@ -302,12 +303,20 @@ class SchedulerOutputProcessorMixin:
|
||||
else:
|
||||
embeddings = [tensor.tolist() for tensor in embeddings]
|
||||
|
||||
if phs is not None:
|
||||
if isinstance(phs, list):
|
||||
phs = [t.cpu().detach() for t in phs]
|
||||
else:
|
||||
phs = phs.cpu().detach()
|
||||
|
||||
# Check finish conditions
|
||||
for i, req in enumerate(batch.reqs):
|
||||
if req.is_retracted:
|
||||
continue
|
||||
|
||||
req.embedding = embeddings[i]
|
||||
if req.return_pooled_hidden_states and phs is not None:
|
||||
req.pooled_hidden_state = phs[i]
|
||||
if req.is_chunked <= 0:
|
||||
req.time_stats.set_prefill_finished_time()
|
||||
# Dummy output token for embedding models
|
||||
@@ -1215,6 +1224,8 @@ class SchedulerOutputProcessorMixin:
|
||||
cached_tokens_details = [] # Detailed breakdown by cache source
|
||||
time_stats = []
|
||||
retraction_counts = []
|
||||
phs_list = []
|
||||
has_phs = False
|
||||
for req in reqs:
|
||||
if req.finished():
|
||||
rids.append(req.rid)
|
||||
@@ -1228,6 +1239,28 @@ class SchedulerOutputProcessorMixin:
|
||||
cached_tokens_details.append(self._get_cached_tokens_details(req))
|
||||
time_stats.append(req.time_stats)
|
||||
retraction_counts.append(req.retraction_count)
|
||||
|
||||
phs = req.pooled_hidden_state
|
||||
phs_list.append(phs)
|
||||
if phs is not None:
|
||||
has_phs = True
|
||||
|
||||
# Optimize PHS for pickle: torch.stack reduces N __reduce_ex__
|
||||
# calls to 1 across the ZMQ IPC boundary. We can only stack when
|
||||
# *every* entry is non-None (homogeneous batch); mixed batches
|
||||
# (some requests want PHS, others don't) keep the raw list so
|
||||
# positional indexing on the receiver side stays correct.
|
||||
stacked_phs = None
|
||||
if has_phs:
|
||||
all_have_phs = all(t is not None for t in phs_list)
|
||||
if all_have_phs:
|
||||
if all(t.shape == phs_list[0].shape for t in phs_list):
|
||||
stacked_phs = torch.stack(phs_list)
|
||||
else:
|
||||
stacked_phs = phs_list
|
||||
else:
|
||||
stacked_phs = phs_list
|
||||
|
||||
self.send_to_detokenizer.send_output(
|
||||
BatchEmbeddingOutput(
|
||||
rids=rids,
|
||||
@@ -1241,5 +1274,6 @@ class SchedulerOutputProcessorMixin:
|
||||
placeholder_tokens_idx=None,
|
||||
placeholder_tokens_val=None,
|
||||
retraction_counts=retraction_counts,
|
||||
pooled_hidden_states=stacked_phs,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1032,6 +1032,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerScoreMixin):
|
||||
dimensions=obj.dimensions,
|
||||
lora_id=obj.lora_id,
|
||||
http_worker_ipc=obj.http_worker_ipc,
|
||||
return_pooled_hidden_states=obj.return_pooled_hidden_states,
|
||||
)
|
||||
|
||||
tokenized_obj.time_stats = self.rid_to_state[obj.rid].time_stats
|
||||
@@ -1776,6 +1777,11 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerScoreMixin):
|
||||
"embedding": recv_obj.embeddings[i],
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
if (
|
||||
recv_obj.pooled_hidden_states is not None
|
||||
and recv_obj.pooled_hidden_states[i] is not None
|
||||
):
|
||||
out_dict["pooled_hidden_state"] = recv_obj.pooled_hidden_states[i]
|
||||
|
||||
# Set first_token_time on the first output batch.
|
||||
# This is the single write point for first_token_time.
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.configs.model_config import is_cross_encoding_pooler_model
|
||||
from sglang.srt.managers.embed_types import PositionalEmbeds
|
||||
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
|
||||
|
||||
@@ -15,6 +16,12 @@ logger = logging.getLogger(__name__)
|
||||
class ScoreResult:
|
||||
scores: List[List[float]]
|
||||
prompt_tokens: int = 0
|
||||
# Per-item pooled hidden states (pre-head transformer output).
|
||||
# CPU tensors when return_pooled_hidden_states=True; kept as tensors so
|
||||
# in-process consumers (gRPC, engine API) avoid a .tolist() round-trip.
|
||||
# The HTTP path converts to lists in serving_score.py before JSON serialization.
|
||||
# Same layout as scores: one tensor per item (not a single packed 2D tensor).
|
||||
pooled_hidden_states: Optional[List[Optional[torch.Tensor]]] = None
|
||||
|
||||
|
||||
class TokenizerManagerScoreMixin:
|
||||
@@ -151,6 +158,7 @@ class TokenizerManagerScoreMixin:
|
||||
label_token_ids: Optional[List[int]],
|
||||
apply_softmax: bool,
|
||||
batch_request=None,
|
||||
return_pooled_hidden_states: bool = False,
|
||||
) -> ScoreResult:
|
||||
"""
|
||||
Process results from multi-item scoring request.
|
||||
@@ -166,11 +174,13 @@ class TokenizerManagerScoreMixin:
|
||||
label_token_ids: Token IDs to extract scores for
|
||||
apply_softmax: Whether to apply softmax normalization
|
||||
batch_request: The original batch request containing input sequence
|
||||
return_pooled_hidden_states: Whether to extract pooled hidden states
|
||||
from the result and include them in the ScoreResult.
|
||||
|
||||
Returns:
|
||||
ScoreResult with:
|
||||
scores: List of score lists, one for each prompt, each in the order of label_token_ids.
|
||||
prompt_tokens: The number of prompt tokens processed.
|
||||
ScoreResult with per-item scores, prompt token count, and optional
|
||||
pooled_hidden_states (when return_pooled_hidden_states=True and the
|
||||
model populated the field).
|
||||
"""
|
||||
single_result = results[0] if isinstance(results, list) else results
|
||||
meta_info = single_result.get("meta_info", {})
|
||||
@@ -225,10 +235,24 @@ class TokenizerManagerScoreMixin:
|
||||
# Skip the first delimiter (query-item boundary)
|
||||
scores = per_delimiter_scores[1:]
|
||||
|
||||
return ScoreResult(scores=scores, prompt_tokens=prompt_tokens)
|
||||
phs_list = None
|
||||
if return_pooled_hidden_states:
|
||||
raw_phs = single_result.get("pooled_hidden_state")
|
||||
if raw_phs is not None and len(raw_phs) == expected_count:
|
||||
phs_list = raw_phs[1:]
|
||||
|
||||
return ScoreResult(
|
||||
scores=scores,
|
||||
prompt_tokens=prompt_tokens,
|
||||
pooled_hidden_states=phs_list,
|
||||
)
|
||||
|
||||
def _process_single_item_scoring_results(
|
||||
self, results: Any, label_token_ids: Optional[List[int]], apply_softmax: bool
|
||||
self,
|
||||
results: Any,
|
||||
label_token_ids: Optional[List[int]],
|
||||
apply_softmax: bool,
|
||||
return_pooled_hidden_states: bool = False,
|
||||
) -> ScoreResult:
|
||||
"""
|
||||
Process results from single-item scoring request.
|
||||
@@ -241,13 +265,14 @@ class TokenizerManagerScoreMixin:
|
||||
results: Results from generate_request
|
||||
label_token_ids: Token IDs to extract scores for (generation models only)
|
||||
apply_softmax: Whether to apply softmax normalization
|
||||
return_pooled_hidden_states: Whether to extract pooled hidden states
|
||||
|
||||
Returns:
|
||||
ScoreResult with:
|
||||
scores: List of score lists, one for each prompt, each in the order of label_token_ids.
|
||||
prompt_tokens: The number of prompt tokens processed.
|
||||
ScoreResult with per-item scores, prompt token count, and optional pooled_hidden_states.
|
||||
"""
|
||||
scores = []
|
||||
phs_list = []
|
||||
has_phs = False
|
||||
prompt_tokens = 0
|
||||
|
||||
is_generation = getattr(self, "is_generation", True)
|
||||
@@ -293,7 +318,17 @@ class TokenizerManagerScoreMixin:
|
||||
# EmbeddingPoolerOutput API.
|
||||
scores.append(embedding)
|
||||
|
||||
return ScoreResult(scores=scores, prompt_tokens=prompt_tokens)
|
||||
if return_pooled_hidden_states:
|
||||
phs = result.get("pooled_hidden_state")
|
||||
phs_list.append(phs)
|
||||
if phs is not None:
|
||||
has_phs = True
|
||||
|
||||
return ScoreResult(
|
||||
scores=scores,
|
||||
prompt_tokens=prompt_tokens,
|
||||
pooled_hidden_states=phs_list if has_phs else None,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Embed override position resolution
|
||||
@@ -481,6 +516,7 @@ class TokenizerManagerScoreMixin:
|
||||
query_embed_overrides: Optional[List[torch.Tensor]] = None,
|
||||
item_embed_overrides: Optional[List[Optional[List[torch.Tensor]]]] = None,
|
||||
request: Optional[Any] = None,
|
||||
return_pooled_hidden_states: bool = False,
|
||||
) -> ScoreResult:
|
||||
"""
|
||||
Score the probability of specified token IDs appearing after the given (query + item) pair.
|
||||
@@ -510,11 +546,18 @@ class TokenizerManagerScoreMixin:
|
||||
query_embed_overrides: Embedding vectors replacing placeholder tokens in query.
|
||||
item_embed_overrides: Per-item embedding vectors replacing placeholder tokens in items.
|
||||
request: Optional FastAPI request object
|
||||
return_pooled_hidden_states: Whether to include the raw pooled transformer
|
||||
hidden states (before the task-specific head) in the result. Only
|
||||
supported for non-generation models (SequenceClassification,
|
||||
RewardModel). Raises ValueError for CausalLM models.
|
||||
|
||||
Returns:
|
||||
ScoreResult with:
|
||||
scores: List of score lists, one for each prompt, each in the order of label_token_ids.
|
||||
scores: List of score lists, one per item.
|
||||
prompt_tokens: The number of prompt tokens processed.
|
||||
pooled_hidden_states: Per-item CPU tensors when
|
||||
return_pooled_hidden_states=True and the model supports it;
|
||||
None otherwise.
|
||||
"""
|
||||
is_generation = getattr(self, "is_generation", True)
|
||||
|
||||
@@ -618,6 +661,23 @@ class TokenizerManagerScoreMixin:
|
||||
"Invalid combination of query/items types for score_request."
|
||||
)
|
||||
|
||||
if return_pooled_hidden_states:
|
||||
if is_generation:
|
||||
raise ValueError(
|
||||
"return_pooled_hidden_states is not supported for CausalLM models. "
|
||||
"It requires a model with a task-specific head "
|
||||
"(e.g. SequenceClassification or RewardModel)."
|
||||
)
|
||||
model_config = getattr(self, "model_config", None)
|
||||
if model_config is not None:
|
||||
archs = getattr(model_config.hf_config, "architectures", []) or []
|
||||
if is_cross_encoding_pooler_model(archs):
|
||||
raise ValueError(
|
||||
f"return_pooled_hidden_states is not supported for "
|
||||
f"{archs[0]}. This model uses CrossEncodingPooler which "
|
||||
f"does not expose pre-head hidden states."
|
||||
)
|
||||
|
||||
# Create the appropriate request type
|
||||
if is_generation:
|
||||
batch_request = GenerateReqInput(
|
||||
@@ -636,6 +696,7 @@ class TokenizerManagerScoreMixin:
|
||||
text=text_prompts,
|
||||
input_ids=input_ids,
|
||||
positional_embed_overrides=positional_embed_overrides,
|
||||
return_pooled_hidden_states=return_pooled_hidden_states,
|
||||
)
|
||||
|
||||
results = await self.generate_request(batch_request, request).__anext__()
|
||||
@@ -643,12 +704,17 @@ class TokenizerManagerScoreMixin:
|
||||
if use_multi_item_scoring:
|
||||
# Multi-item scoring: extract scores from input_token_ids_logprobs or embedding
|
||||
return self._process_multi_item_scoring_results(
|
||||
results, items, label_token_ids, apply_softmax, batch_request
|
||||
results,
|
||||
items,
|
||||
label_token_ids,
|
||||
apply_softmax,
|
||||
batch_request,
|
||||
return_pooled_hidden_states,
|
||||
)
|
||||
else:
|
||||
# Single-item scoring: process each result separately
|
||||
return self._process_single_item_scoring_results(
|
||||
results, label_token_ids, apply_softmax
|
||||
results, label_token_ids, apply_softmax, return_pooled_hidden_states
|
||||
)
|
||||
|
||||
def _convert_logprobs_to_scores(
|
||||
|
||||
@@ -210,9 +210,8 @@ class BaseTpWorker(ABC):
|
||||
|
||||
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||
logits_output = self.model_runner.forward(forward_batch).logits_output
|
||||
embeddings = logits_output.embeddings
|
||||
return embeddings
|
||||
output = self.model_runner.forward(forward_batch).logits_output
|
||||
return output # Returns EmbeddingPoolerOutput
|
||||
|
||||
|
||||
class TpModelWorker(BaseTpWorker):
|
||||
|
||||
@@ -427,6 +427,9 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin):
|
||||
# For hidden states before normal
|
||||
return_hidden_states_before_norm: bool = False
|
||||
|
||||
# Whether to return pooled hidden states (pre-head transformer output)
|
||||
return_pooled_hidden_states: bool = False
|
||||
|
||||
# For hisparse
|
||||
hisparse_coordinator: Optional[HiSparseCoordinator] = None
|
||||
|
||||
@@ -483,6 +486,7 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin):
|
||||
tbo_split_seq_index=batch.tbo_split_seq_index,
|
||||
dimensions=batch.dimensions,
|
||||
return_hidden_states_before_norm=batch.return_hidden_states_before_norm,
|
||||
return_pooled_hidden_states=batch.return_pooled_hidden_states,
|
||||
rids=[req.rid for req in batch.reqs],
|
||||
)
|
||||
device = model_runner.device
|
||||
|
||||
@@ -206,6 +206,9 @@ class PiecewiseCudaGraphRunner:
|
||||
|
||||
self.is_multimodal = model_runner.is_multimodal
|
||||
self.mamba_track_enabled = self.is_mamba_track_enabled()
|
||||
# Classification/reward forwards branch on return_pooled_hidden_states; piecewise
|
||||
# CUDA graph capture must use the same flag value as replay for those models.
|
||||
self.capture_return_pooled_hidden_states = not model_runner.is_generation
|
||||
|
||||
# Graph inputs
|
||||
with torch.device(self.device):
|
||||
@@ -390,6 +393,7 @@ class PiecewiseCudaGraphRunner:
|
||||
num_token_non_padded_cpu=num_tokens,
|
||||
global_forward_mode=ForwardMode.EXTEND,
|
||||
lora_ids=None,
|
||||
return_pooled_hidden_states=self.capture_return_pooled_hidden_states,
|
||||
)
|
||||
|
||||
# Attention backend
|
||||
@@ -551,6 +555,7 @@ class PiecewiseCudaGraphRunner:
|
||||
num_token_non_padded_cpu=num_tokens,
|
||||
global_forward_mode=ForwardMode.EXTEND,
|
||||
lora_ids=None,
|
||||
return_pooled_hidden_states=self.capture_return_pooled_hidden_states,
|
||||
)
|
||||
self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
|
||||
|
||||
@@ -748,6 +753,10 @@ class PiecewiseCudaGraphRunner:
|
||||
top_p_normalized_logprobs=forward_batch.top_p_normalized_logprobs,
|
||||
top_p=forward_batch.top_p,
|
||||
dimensions=forward_batch.dimensions,
|
||||
return_pooled_hidden_states=(
|
||||
self.capture_return_pooled_hidden_states
|
||||
or forward_batch.return_pooled_hidden_states
|
||||
),
|
||||
)
|
||||
|
||||
if out_cache_loc_swa is not None:
|
||||
|
||||
@@ -61,7 +61,12 @@ class Gemma2ForSequenceClassification(nn.Module):
|
||||
last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
|
||||
scores = self.score(last_token_hidden)
|
||||
|
||||
return EmbeddingPoolerOutput(scores)
|
||||
return EmbeddingPoolerOutput(
|
||||
embeddings=scores,
|
||||
pooled_hidden_states=(
|
||||
last_token_hidden if forward_batch.return_pooled_hidden_states else None
|
||||
),
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
Gemma2ForCausalLM.load_weights(self, weights)
|
||||
|
||||
@@ -55,7 +55,12 @@ class InternLM2ForRewardModel(nn.Module):
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
|
||||
scores = self.v_head(last_token_hidden)
|
||||
return EmbeddingPoolerOutput(scores)
|
||||
return EmbeddingPoolerOutput(
|
||||
embeddings=scores,
|
||||
pooled_hidden_states=(
|
||||
last_token_hidden if forward_batch.return_pooled_hidden_states else None
|
||||
),
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
return InternLM2ForCausalLM.load_weights(self, weights)
|
||||
|
||||
@@ -18,7 +18,12 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
|
||||
from sglang.srt.layers.pooler import (
|
||||
EmbeddingPoolerOutput,
|
||||
Pooler,
|
||||
PoolingType,
|
||||
pool_hidden_states,
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
|
||||
@@ -61,7 +66,12 @@ class LlamaForSequenceClassification(nn.Module):
|
||||
last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
|
||||
scores = self.score(last_token_hidden)
|
||||
|
||||
return EmbeddingPoolerOutput(scores)
|
||||
return EmbeddingPoolerOutput(
|
||||
embeddings=scores,
|
||||
pooled_hidden_states=(
|
||||
last_token_hidden if forward_batch.return_pooled_hidden_states else None
|
||||
),
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
return LlamaForCausalLM.load_weights(self, weights)
|
||||
@@ -114,7 +124,17 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
|
||||
-1, self.num_labels // 2
|
||||
)
|
||||
scores = (rews * pooled_weights).sum(dim=-1).view(-1, 1)
|
||||
return EmbeddingPoolerOutput(scores)
|
||||
|
||||
pooled_hidden = None
|
||||
if forward_batch.return_pooled_hidden_states:
|
||||
pooled_hidden = pool_hidden_states(
|
||||
self.pooler.pooling_type, hidden_states, forward_batch
|
||||
)
|
||||
|
||||
return EmbeddingPoolerOutput(
|
||||
embeddings=scores,
|
||||
pooled_hidden_states=pooled_hidden,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
return super().load_weights(weights)
|
||||
|
||||
@@ -18,7 +18,12 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import Qwen2Config
|
||||
|
||||
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
|
||||
from sglang.srt.layers.pooler import (
|
||||
EmbeddingPoolerOutput,
|
||||
Pooler,
|
||||
PoolingType,
|
||||
pool_hidden_states,
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM, Qwen2Model
|
||||
@@ -63,7 +68,16 @@ class Qwen2ForRewardModel(nn.Module):
|
||||
logits = self.score(hidden_states)
|
||||
pooled_logits = self.pooler(logits, forward_batch).embeddings
|
||||
|
||||
return EmbeddingPoolerOutput(pooled_logits)
|
||||
pooled_hidden = None
|
||||
if forward_batch.return_pooled_hidden_states:
|
||||
pooled_hidden = pool_hidden_states(
|
||||
self.pooler.pooling_type, hidden_states, forward_batch
|
||||
)
|
||||
|
||||
return EmbeddingPoolerOutput(
|
||||
embeddings=pooled_logits,
|
||||
pooled_hidden_states=pooled_hidden,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
# Filter out lm_head weights of Qwen2ForCausalLM
|
||||
|
||||
@@ -800,6 +800,9 @@ class ServerArgs:
|
||||
# Handle piecewise CUDA graph.
|
||||
self._handle_piecewise_cuda_graph()
|
||||
|
||||
# Handle multi-item scoring constraints.
|
||||
self._handle_multi_item_scoring()
|
||||
|
||||
# Get GPU memory capacity, which is a common dependency for several configuration steps.
|
||||
gpu_mem = get_device_memory_capacity(self.device)
|
||||
|
||||
@@ -1205,6 +1208,21 @@ class ServerArgs:
|
||||
if self.debug_cuda_graph:
|
||||
self.disable_piecewise_cuda_graph = True
|
||||
|
||||
def _handle_multi_item_scoring(self):
|
||||
"""Disable CUDA graphs when multi-item scoring delimiter is set.
|
||||
|
||||
The padded static input_ids buffer used by CUDA graph replay causes
|
||||
spurious delimiter matches in score_and_pool's MIS path.
|
||||
"""
|
||||
if self.multi_item_scoring_delimiter is None:
|
||||
return
|
||||
if not self.disable_cuda_graph:
|
||||
logger.warning(
|
||||
"CUDA graph is disabled because --multi-item-scoring-delimiter is set."
|
||||
)
|
||||
self.disable_cuda_graph = True
|
||||
self.disable_piecewise_cuda_graph = True
|
||||
|
||||
def _handle_gpu_memory_settings(self, gpu_mem):
|
||||
"""
|
||||
Configure GPU memory-dependent settings including
|
||||
|
||||
Reference in New Issue
Block a user