diff --git a/python/sglang/srt/batch_overlap/two_batch_overlap.py b/python/sglang/srt/batch_overlap/two_batch_overlap.py index 06027edce..4568e3410 100644 --- a/python/sglang/srt/batch_overlap/two_batch_overlap.py +++ b/python/sglang/srt/batch_overlap/two_batch_overlap.py @@ -702,6 +702,7 @@ class TboForwardBatchPreparer: "mrope_positions", # only used by qwen2-vl, thus not care "split_index", # for split prefill "orig_seq_lens", # only used by qwen-1m, thus not care + "return_pooled_hidden_states", ]: output_dict[key] = getattr(batch, key) if not batch.forward_mode.is_target_verify(): diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 67cce0e63..779551626 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -1436,6 +1436,17 @@ def is_piecewise_cuda_graph_disabled_model(model_architectures: List[str]): ) +# SequenceClassification models that use CrossEncodingPooler +_cross_encoding_pooler_archs = [ + "BertForSequenceClassification", + "XLMRobertaForSequenceClassification", +] + + +def is_cross_encoding_pooler_model(model_architectures: List[str]) -> bool: + return any(arch in _cross_encoding_pooler_archs for arch in model_architectures) + + def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: if scale <= 1: return 1.0 diff --git a/python/sglang/srt/entrypoints/engine_score_mixin.py b/python/sglang/srt/entrypoints/engine_score_mixin.py index 089693e80..085e006ce 100644 --- a/python/sglang/srt/entrypoints/engine_score_mixin.py +++ b/python/sglang/srt/entrypoints/engine_score_mixin.py @@ -33,12 +33,10 @@ class EngineScoreMixin: label_token_ids: Optional[List[int]] = None, apply_softmax: bool = False, item_first: bool = False, - # Placeholder token id in query/items that indicates override locations. embed_override_token_id: Optional[int] = None, - # Query embedding overrides. query_embed_overrides: Optional[List[torch.Tensor]] = None, - # Item embedding overrides: per-item list of tensors. item_embed_overrides: Optional[List[Optional[List[torch.Tensor]]]] = None, + return_pooled_hidden_states: bool = False, ) -> ScoreResult: """ Score items against a query using the loaded model. @@ -63,9 +61,13 @@ class EngineScoreMixin: embed_override_token_id: Placeholder token ID used to locate override positions. query_embed_overrides: Embedding vectors replacing placeholder tokens in query. item_embed_overrides: Per-item embedding vectors replacing placeholder tokens in items. + return_pooled_hidden_states: Whether to include raw pooled transformer + hidden states (before the task head) in the result. Only supported + for non-generation models (SequenceClassification, RewardModel). Returns: - ScoreResult with scores (one list per item) and prompt token count. + ScoreResult with scores (one list per item), prompt token count, and + optional pooled_hidden_states tensors. """ return self.loop.run_until_complete( self.tokenizer_manager.score_request( @@ -78,6 +80,7 @@ class EngineScoreMixin: query_embed_overrides=query_embed_overrides, item_embed_overrides=item_embed_overrides, request=None, + return_pooled_hidden_states=return_pooled_hidden_states, ) ) @@ -91,6 +94,7 @@ class EngineScoreMixin: embed_override_token_id: Optional[int] = None, query_embed_overrides: Optional[List[torch.Tensor]] = None, item_embed_overrides: Optional[List[Optional[List[torch.Tensor]]]] = None, + return_pooled_hidden_states: bool = False, ) -> ScoreResult: """Asynchronous version of score(). See score() for full documentation.""" return await self.tokenizer_manager.score_request( @@ -103,4 +107,5 @@ class EngineScoreMixin: query_embed_overrides=query_embed_overrides, item_embed_overrides=item_embed_overrides, request=None, + return_pooled_hidden_states=return_pooled_hidden_states, ) diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index d92bb8ef6..b1d13d551 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -1016,6 +1016,7 @@ class ScoringRequest(BaseModel): ) apply_softmax: bool = False item_first: bool = False + return_pooled_hidden_states: bool = False model: str = DEFAULT_MODEL_NAME @@ -1023,6 +1024,7 @@ class ScoringResponse(BaseModel): scores: List[ List[float] ] # List of lists of probabilities, each in the order of label_token_ids + pooled_hidden_states: Optional[List[Optional[List[float]]]] = None model: str usage: Optional[UsageInfo] = None object: str = "scoring" diff --git a/python/sglang/srt/entrypoints/openai/serving_score.py b/python/sglang/srt/entrypoints/openai/serving_score.py index ff480b632..9a81716f3 100644 --- a/python/sglang/srt/entrypoints/openai/serving_score.py +++ b/python/sglang/srt/entrypoints/openai/serving_score.py @@ -3,6 +3,7 @@ from typing import Union import torch from fastapi import Request +from fastapi.responses import ORJSONResponse from sglang.srt.entrypoints.openai.protocol import ( ErrorResponse, @@ -76,17 +77,26 @@ class OpenAIServingScore(OpenAIServingBase): query_embed_overrides=query_embed_overrides, item_embed_overrides=item_embed_overrides, request=raw_request, + return_pooled_hidden_states=request.return_pooled_hidden_states, ) + phs_as_lists = None + if result.pooled_hidden_states is not None: + phs_as_lists = [ + t.tolist() if t is not None else None + for t in result.pooled_hidden_states + ] + response = ScoringResponse( scores=result.scores, + pooled_hidden_states=phs_as_lists, model=request.model, usage=UsageInfo( prompt_tokens=result.prompt_tokens, total_tokens=result.prompt_tokens, ), ) - return response + return ORJSONResponse(content=response.model_dump()) except ValueError as e: return self.create_error_response(str(e)) diff --git a/python/sglang/srt/layers/pooler.py b/python/sglang/srt/layers/pooler.py index 7dedeee8a..582bf0a97 100644 --- a/python/sglang/srt/layers/pooler.py +++ b/python/sglang/srt/layers/pooler.py @@ -1,18 +1,22 @@ # adapted from # https://github.com/vllm-project/vllm/blob/82a1b1a82b1fbb454c82a9ef95730b929c9b270c/vllm/model_executor/layers/pooler.py +from __future__ import annotations + from dataclasses import dataclass from enum import IntEnum -from typing import Optional +from typing import TYPE_CHECKING, List, Optional import torch import torch.nn as nn from transformers import PretrainedConfig from sglang.srt.layers.activation import get_cross_encoder_activation_function -from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.server_args import get_global_server_args +if TYPE_CHECKING: + from sglang.srt.model_executor.forward_batch_info import ForwardBatch + class PoolingType(IntEnum): LAST = 0 @@ -21,9 +25,45 @@ class PoolingType(IntEnum): @dataclass class EmbeddingPoolerOutput: + """Output of pooler or score_and_pool. + + Attributes: + embeddings: Pooled embeddings or classification logits. May be a list + of tensors when per-request matryoshka dim truncation produces + different shapes, or when MIS yields a variable number of scores + per request. + pooled_hidden_states: Raw transformer hidden states *before* the + task-specific head, present only when + ``forward_batch.return_pooled_hidden_states`` is True. Tensor + (standard path) or list of tensors (MIS path, one per delimiter). + """ + # Pooler can return list[tensor] instead of tensor if the dimension of each tensor in the batch is different # due to different per-request matryoshka dim truncation embeddings: torch.Tensor | list[torch.Tensor] + pooled_hidden_states: Optional[torch.Tensor | list[torch.Tensor]] = None + + +def pool_hidden_states( + pooling_type: PoolingType, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, +) -> torch.Tensor: + """Pool hidden_states by PoolingType (LAST/CLS). + + Raw pooling only — no normalize, no dim truncation. + Returns shape (batch_size, hidden_size). + """ + if pooling_type == PoolingType.LAST: + last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1 + return hidden_states[last_token_indices] + elif pooling_type == PoolingType.CLS: + prompt_lens = forward_batch.extend_seq_lens + first_token_flat_indices = torch.zeros_like(prompt_lens) + first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1] + return hidden_states[first_token_flat_indices] + else: + raise ValueError(f"Unsupported pooling type: {pooling_type}") def score_and_pool( @@ -33,16 +73,16 @@ def score_and_pool( forward_batch: ForwardBatch, input_ids: torch.Tensor, ) -> EmbeddingPoolerOutput: - """Apply a classification/score head with multi-item scoring (MIS) support. + """Apply a classification/score head with MIS and pooled-hidden-states support. - When ``multi_item_scoring_delimiter`` is configured and found in - ``input_ids``, takes the MIS path: extract hidden states at the positions - just before each delimiter, apply the score head only to those positions, - then split results per-request using ``forward_batch.extend_seq_lens``. + MIS path (when ``multi_item_scoring_delimiter`` is set and found in ``input_ids``): + extract hidden states at positions just before each delimiter, apply the score head, + then split per-request. - Otherwise, takes the normal single-item path: apply the score head to all - hidden states, then pool (matching the original classification model - forward logic). + Standard path: apply the score head to all hidden states, then pool. + + When ``forward_batch.return_pooled_hidden_states`` is True, the raw pooled + hidden states (before the score head) are included in the output. """ delimiter_token = get_global_server_args().multi_item_scoring_delimiter if delimiter_token is not None and forward_batch.is_prefill_only: @@ -52,26 +92,40 @@ def score_and_pool( if delim_positions.numel() > 0: # Score only the tokens that precede a delimiter - scores = score_head(hidden_states[delim_positions - 1]) + pre_delim_hidden = hidden_states[delim_positions - 1] + scores = score_head(pre_delim_hidden) # Split per-request so the scheduler gets one tensor per request. # Use CPU sequence lengths to avoid per-iteration GPU<->CPU sync # from `.item()` calls on device tensors. seq_lens = forward_batch.extend_seq_lens_cpu start = 0 - per_request = [] + per_request_scores: List[torch.Tensor] = [] + per_request_phs: Optional[List[torch.Tensor]] = ( + [] if forward_batch.return_pooled_hidden_states else None + ) for seq_len in seq_lens: end = start + seq_len mask = (delim_positions >= start) & (delim_positions < end) - per_request.append(scores[mask]) + per_request_scores.append(scores[mask]) + if per_request_phs is not None: + per_request_phs.append(pre_delim_hidden[mask]) start = end - return EmbeddingPoolerOutput(embeddings=per_request) + return EmbeddingPoolerOutput( + embeddings=per_request_scores, + pooled_hidden_states=per_request_phs, + ) - # Standard classification path: score all tokens, then pool. - logits = score_head(hidden_states) - pooled_logits = pooler(logits, forward_batch).embeddings - return EmbeddingPoolerOutput(embeddings=pooled_logits) + # Standard classification path: pool hidden states, then score. + pooled_hs = pool_hidden_states(pooler.pooling_type, hidden_states, forward_batch) + scores = score_head(pooled_hs) + return EmbeddingPoolerOutput( + embeddings=scores, + pooled_hidden_states=( + pooled_hs if forward_batch.return_pooled_hidden_states else None + ), + ) class Pooler(nn.Module): @@ -93,17 +147,9 @@ class Pooler(nn.Module): def forward( self, hidden_states: torch.Tensor, forward_batch: ForwardBatch ) -> EmbeddingPoolerOutput: - - if self.pooling_type == PoolingType.LAST: - last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1 - pooled_data = hidden_states[last_token_indices] - elif self.pooling_type == PoolingType.CLS: - prompt_lens = forward_batch.extend_seq_lens - first_token_flat_indices = torch.zeros_like(prompt_lens) - first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1] - pooled_data = hidden_states[first_token_flat_indices] - else: - raise ValueError(f"Invalid pooling type: {self.pooling_type}") + pooled_data = pool_hidden_states( + self.pooling_type, hidden_states, forward_batch + ) if forward_batch.dimensions is not None: all_same_dimensions = len(set(forward_batch.dimensions)) == 1 diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index f06b40c4e..94a26966e 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -852,6 +852,9 @@ class EmbeddingReqInput(BaseReq): # The uid of LoRA adaptors, should be initialized by tokenizer manager lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None + # Whether to return pooled hidden states (pre-head transformer output) + return_pooled_hidden_states: bool = False + def normalize_batch_and_arguments(self): # at least one of text, input_ids, or image should be provided if self.text is None and self.input_ids is None and self.image_data is None: @@ -953,6 +956,7 @@ class EmbeddingReqInput(BaseReq): lora_id=self.lora_id[i] if self.lora_id is not None else None, is_cross_encoder_request=True, http_worker_ipc=self.http_worker_ipc, + return_pooled_hidden_states=self.return_pooled_hidden_states, ) else: sub = EmbeddingReqInput( @@ -976,6 +980,7 @@ class EmbeddingReqInput(BaseReq): dimensions=self.dimensions, http_worker_ipc=self.http_worker_ipc, received_time=self.received_time, + return_pooled_hidden_states=self.return_pooled_hidden_states, ) cache[i] = sub return sub @@ -1007,6 +1012,9 @@ class TokenizedEmbeddingReqInput(BaseReq): # For observability time_stats: Optional[Union[APIServerReqTimeStats, DPControllerReqTimeStats]] = None + # Whether to return pooled hidden states (pre-head transformer output) + return_pooled_hidden_states: bool = False + @dataclass class BatchTokenizedEmbeddingReqInput(BaseBatchReq): @@ -1175,6 +1183,12 @@ class BatchEmbeddingOutput(BaseBatchReq): # For observability time_stats: Optional[List[SchedulerReqTimeStats]] = None + # Optional pooled hidden states (pre-head transformer output). + # Sent as a single stacked tensor to minimize pickle overhead. + pooled_hidden_states: Optional[ + Union[List[Optional[torch.Tensor]], torch.Tensor] + ] = None + @dataclass class ClearHiCacheReqInput(BaseReq): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a6172890f..faba71368 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -596,6 +596,7 @@ class Req(ReqDllmMixin): time_stats: Optional[ Union[APIServerReqTimeStats, DPControllerReqTimeStats] ] = None, + return_pooled_hidden_states: bool = False, ): # Input and output info self.rid = rid @@ -872,6 +873,10 @@ class Req(ReqDllmMixin): # For Matryoshka embeddings self.dimensions = dimensions + # Whether to return pooled hidden states (pre-head transformer output) + self.return_pooled_hidden_states = return_pooled_hidden_states + self.pooled_hidden_state = None + # For diffusion LLM self.init_diffusion_llm(dllm_config) @@ -1403,6 +1408,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # For matryoshka embeddings dimensions: Optional[list[int]] = None + # Whether to return pooled hidden states (pre-head transformer output) + return_pooled_hidden_states: bool = False + # For split prefill split_index: int = 0 split_prefill_finished: bool = False @@ -1594,6 +1602,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): for r in reqs ] + # OR across the batch so ForwardBatch matches a single fused forward; requests + # that did not ask for PHS still skip attaching it in the output processor. + self.return_pooled_hidden_states = any( + r.return_pooled_hidden_states for r in reqs + ) + token_type_ids = [ r.token_type_ids for r in reqs if r.token_type_ids is not None ] @@ -2439,6 +2453,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): extend_input_logprob_token_ids=self.extend_input_logprob_token_ids, is_prefill_only=self.is_prefill_only, dimensions=self.dimensions, + return_pooled_hidden_states=self.return_pooled_hidden_states, dllm_block_offsets=[req.dllm_block_offset for req in self.reqs], dllm_config=self.dllm_config, reqs=self.reqs, @@ -2632,6 +2647,9 @@ class ModelWorkerBatch: # For matryoshka embeddings dimensions: Optional[list[int]] = None + # Whether to return pooled hidden states (pre-head transformer output) + return_pooled_hidden_states: bool = False + # Whether this batch is prefill-only (no token generation needed) is_prefill_only: bool = False diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index dd67a0d38..16193d132 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -254,12 +254,22 @@ _is_npu = is_npu() @dataclass class EmbeddingBatchResult: + """Result from an embedding/classification forward pass. + + Attributes: + embeddings: Model output — pooled embeddings or classification logits. + pooled_hidden_states: Raw hidden states before the task head. Present + only when the batch contained ``return_pooled_hidden_states=True`` + requests. Tensor (uniform shapes) or list of tensors (MIS). + copy_done: CUDA event recorded after the async CPU copy completes. + """ + embeddings: torch.Tensor + pooled_hidden_states: Optional[torch.Tensor] = None copy_done: Optional[torch.cuda.Event] = None def copy_to_cpu(self): - """Copy embeddings tensor to CPU in overlap scheduling.""" - + """Copy embeddings and pooled hidden states to CPU for overlap scheduling.""" if isinstance(self.embeddings, torch.Tensor): self.copy_done = torch.get_device_module(self.embeddings.device).Event() self.embeddings = self.embeddings.to("cpu", non_blocking=True) @@ -273,6 +283,16 @@ class EmbeddingBatchResult: emb.to("cpu", non_blocking=True) for emb in self.embeddings ] + if self.pooled_hidden_states is not None: + if isinstance(self.pooled_hidden_states, list): + self.pooled_hidden_states = [ + t.to("cpu", non_blocking=True) for t in self.pooled_hidden_states + ] + else: + self.pooled_hidden_states = self.pooled_hidden_states.to( + "cpu", non_blocking=True + ) + self.copy_done.record() @@ -2157,6 +2177,7 @@ class Scheduler( lora_id=recv_req.lora_id, http_worker_ipc=recv_req.http_worker_ipc, time_stats=recv_req.time_stats, + return_pooled_hidden_states=recv_req.return_pooled_hidden_states, ) req.tokenizer = self.tokenizer @@ -2828,14 +2849,22 @@ class Scheduler( self.record_batch_in_overlap(model_worker_batch) with self.forward_stream_ctx, self.record_bubble_metrics(batch): self.forward_stream.wait_stream(self.schedule_stream) - embeddings = self.tp_worker.forward_batch_embedding( + pooler_output = self.tp_worker.forward_batch_embedding( model_worker_batch ) - ret = EmbeddingBatchResult(embeddings=embeddings) + ret = EmbeddingBatchResult( + embeddings=pooler_output.embeddings, + pooled_hidden_states=pooler_output.pooled_hidden_states, + ) ret.copy_to_cpu() else: - embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch) - ret = EmbeddingBatchResult(embeddings=embeddings) + pooler_output = self.tp_worker.forward_batch_embedding( + model_worker_batch + ) + ret = EmbeddingBatchResult( + embeddings=pooler_output.embeddings, + pooled_hidden_states=pooler_output.pooled_hidden_states, + ) # Capture prefill end time for EXTEND mode if batch.forward_mode == ForwardMode.EXTEND: diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index c8172f603..fc8f4855b 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -286,6 +286,7 @@ class SchedulerOutputProcessorMixin: is_sparse = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set() embeddings = result.embeddings + phs = result.pooled_hidden_states if is_sparse: batch_ids, token_ids = embeddings.indices() @@ -302,12 +303,20 @@ class SchedulerOutputProcessorMixin: else: embeddings = [tensor.tolist() for tensor in embeddings] + if phs is not None: + if isinstance(phs, list): + phs = [t.cpu().detach() for t in phs] + else: + phs = phs.cpu().detach() + # Check finish conditions for i, req in enumerate(batch.reqs): if req.is_retracted: continue req.embedding = embeddings[i] + if req.return_pooled_hidden_states and phs is not None: + req.pooled_hidden_state = phs[i] if req.is_chunked <= 0: req.time_stats.set_prefill_finished_time() # Dummy output token for embedding models @@ -1215,6 +1224,8 @@ class SchedulerOutputProcessorMixin: cached_tokens_details = [] # Detailed breakdown by cache source time_stats = [] retraction_counts = [] + phs_list = [] + has_phs = False for req in reqs: if req.finished(): rids.append(req.rid) @@ -1228,6 +1239,28 @@ class SchedulerOutputProcessorMixin: cached_tokens_details.append(self._get_cached_tokens_details(req)) time_stats.append(req.time_stats) retraction_counts.append(req.retraction_count) + + phs = req.pooled_hidden_state + phs_list.append(phs) + if phs is not None: + has_phs = True + + # Optimize PHS for pickle: torch.stack reduces N __reduce_ex__ + # calls to 1 across the ZMQ IPC boundary. We can only stack when + # *every* entry is non-None (homogeneous batch); mixed batches + # (some requests want PHS, others don't) keep the raw list so + # positional indexing on the receiver side stays correct. + stacked_phs = None + if has_phs: + all_have_phs = all(t is not None for t in phs_list) + if all_have_phs: + if all(t.shape == phs_list[0].shape for t in phs_list): + stacked_phs = torch.stack(phs_list) + else: + stacked_phs = phs_list + else: + stacked_phs = phs_list + self.send_to_detokenizer.send_output( BatchEmbeddingOutput( rids=rids, @@ -1241,5 +1274,6 @@ class SchedulerOutputProcessorMixin: placeholder_tokens_idx=None, placeholder_tokens_val=None, retraction_counts=retraction_counts, + pooled_hidden_states=stacked_phs, ) ) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 09a9df885..3968411ee 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1032,6 +1032,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerScoreMixin): dimensions=obj.dimensions, lora_id=obj.lora_id, http_worker_ipc=obj.http_worker_ipc, + return_pooled_hidden_states=obj.return_pooled_hidden_states, ) tokenized_obj.time_stats = self.rid_to_state[obj.rid].time_stats @@ -1776,6 +1777,11 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerScoreMixin): "embedding": recv_obj.embeddings[i], "meta_info": meta_info, } + if ( + recv_obj.pooled_hidden_states is not None + and recv_obj.pooled_hidden_states[i] is not None + ): + out_dict["pooled_hidden_state"] = recv_obj.pooled_hidden_states[i] # Set first_token_time on the first output batch. # This is the single write point for first_token_time. diff --git a/python/sglang/srt/managers/tokenizer_manager_score_mixin.py b/python/sglang/srt/managers/tokenizer_manager_score_mixin.py index e6180430e..d520cf782 100644 --- a/python/sglang/srt/managers/tokenizer_manager_score_mixin.py +++ b/python/sglang/srt/managers/tokenizer_manager_score_mixin.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch +from sglang.srt.configs.model_config import is_cross_encoding_pooler_model from sglang.srt.managers.embed_types import PositionalEmbeds from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput @@ -15,6 +16,12 @@ logger = logging.getLogger(__name__) class ScoreResult: scores: List[List[float]] prompt_tokens: int = 0 + # Per-item pooled hidden states (pre-head transformer output). + # CPU tensors when return_pooled_hidden_states=True; kept as tensors so + # in-process consumers (gRPC, engine API) avoid a .tolist() round-trip. + # The HTTP path converts to lists in serving_score.py before JSON serialization. + # Same layout as scores: one tensor per item (not a single packed 2D tensor). + pooled_hidden_states: Optional[List[Optional[torch.Tensor]]] = None class TokenizerManagerScoreMixin: @@ -151,6 +158,7 @@ class TokenizerManagerScoreMixin: label_token_ids: Optional[List[int]], apply_softmax: bool, batch_request=None, + return_pooled_hidden_states: bool = False, ) -> ScoreResult: """ Process results from multi-item scoring request. @@ -166,11 +174,13 @@ class TokenizerManagerScoreMixin: label_token_ids: Token IDs to extract scores for apply_softmax: Whether to apply softmax normalization batch_request: The original batch request containing input sequence + return_pooled_hidden_states: Whether to extract pooled hidden states + from the result and include them in the ScoreResult. Returns: - ScoreResult with: - scores: List of score lists, one for each prompt, each in the order of label_token_ids. - prompt_tokens: The number of prompt tokens processed. + ScoreResult with per-item scores, prompt token count, and optional + pooled_hidden_states (when return_pooled_hidden_states=True and the + model populated the field). """ single_result = results[0] if isinstance(results, list) else results meta_info = single_result.get("meta_info", {}) @@ -225,10 +235,24 @@ class TokenizerManagerScoreMixin: # Skip the first delimiter (query-item boundary) scores = per_delimiter_scores[1:] - return ScoreResult(scores=scores, prompt_tokens=prompt_tokens) + phs_list = None + if return_pooled_hidden_states: + raw_phs = single_result.get("pooled_hidden_state") + if raw_phs is not None and len(raw_phs) == expected_count: + phs_list = raw_phs[1:] + + return ScoreResult( + scores=scores, + prompt_tokens=prompt_tokens, + pooled_hidden_states=phs_list, + ) def _process_single_item_scoring_results( - self, results: Any, label_token_ids: Optional[List[int]], apply_softmax: bool + self, + results: Any, + label_token_ids: Optional[List[int]], + apply_softmax: bool, + return_pooled_hidden_states: bool = False, ) -> ScoreResult: """ Process results from single-item scoring request. @@ -241,13 +265,14 @@ class TokenizerManagerScoreMixin: results: Results from generate_request label_token_ids: Token IDs to extract scores for (generation models only) apply_softmax: Whether to apply softmax normalization + return_pooled_hidden_states: Whether to extract pooled hidden states Returns: - ScoreResult with: - scores: List of score lists, one for each prompt, each in the order of label_token_ids. - prompt_tokens: The number of prompt tokens processed. + ScoreResult with per-item scores, prompt token count, and optional pooled_hidden_states. """ scores = [] + phs_list = [] + has_phs = False prompt_tokens = 0 is_generation = getattr(self, "is_generation", True) @@ -293,7 +318,17 @@ class TokenizerManagerScoreMixin: # EmbeddingPoolerOutput API. scores.append(embedding) - return ScoreResult(scores=scores, prompt_tokens=prompt_tokens) + if return_pooled_hidden_states: + phs = result.get("pooled_hidden_state") + phs_list.append(phs) + if phs is not None: + has_phs = True + + return ScoreResult( + scores=scores, + prompt_tokens=prompt_tokens, + pooled_hidden_states=phs_list if has_phs else None, + ) # ------------------------------------------------------------------ # Embed override position resolution @@ -481,6 +516,7 @@ class TokenizerManagerScoreMixin: query_embed_overrides: Optional[List[torch.Tensor]] = None, item_embed_overrides: Optional[List[Optional[List[torch.Tensor]]]] = None, request: Optional[Any] = None, + return_pooled_hidden_states: bool = False, ) -> ScoreResult: """ Score the probability of specified token IDs appearing after the given (query + item) pair. @@ -510,11 +546,18 @@ class TokenizerManagerScoreMixin: query_embed_overrides: Embedding vectors replacing placeholder tokens in query. item_embed_overrides: Per-item embedding vectors replacing placeholder tokens in items. request: Optional FastAPI request object + return_pooled_hidden_states: Whether to include the raw pooled transformer + hidden states (before the task-specific head) in the result. Only + supported for non-generation models (SequenceClassification, + RewardModel). Raises ValueError for CausalLM models. Returns: ScoreResult with: - scores: List of score lists, one for each prompt, each in the order of label_token_ids. + scores: List of score lists, one per item. prompt_tokens: The number of prompt tokens processed. + pooled_hidden_states: Per-item CPU tensors when + return_pooled_hidden_states=True and the model supports it; + None otherwise. """ is_generation = getattr(self, "is_generation", True) @@ -618,6 +661,23 @@ class TokenizerManagerScoreMixin: "Invalid combination of query/items types for score_request." ) + if return_pooled_hidden_states: + if is_generation: + raise ValueError( + "return_pooled_hidden_states is not supported for CausalLM models. " + "It requires a model with a task-specific head " + "(e.g. SequenceClassification or RewardModel)." + ) + model_config = getattr(self, "model_config", None) + if model_config is not None: + archs = getattr(model_config.hf_config, "architectures", []) or [] + if is_cross_encoding_pooler_model(archs): + raise ValueError( + f"return_pooled_hidden_states is not supported for " + f"{archs[0]}. This model uses CrossEncodingPooler which " + f"does not expose pre-head hidden states." + ) + # Create the appropriate request type if is_generation: batch_request = GenerateReqInput( @@ -636,6 +696,7 @@ class TokenizerManagerScoreMixin: text=text_prompts, input_ids=input_ids, positional_embed_overrides=positional_embed_overrides, + return_pooled_hidden_states=return_pooled_hidden_states, ) results = await self.generate_request(batch_request, request).__anext__() @@ -643,12 +704,17 @@ class TokenizerManagerScoreMixin: if use_multi_item_scoring: # Multi-item scoring: extract scores from input_token_ids_logprobs or embedding return self._process_multi_item_scoring_results( - results, items, label_token_ids, apply_softmax, batch_request + results, + items, + label_token_ids, + apply_softmax, + batch_request, + return_pooled_hidden_states, ) else: # Single-item scoring: process each result separately return self._process_single_item_scoring_results( - results, label_token_ids, apply_softmax + results, label_token_ids, apply_softmax, return_pooled_hidden_states ) def _convert_logprobs_to_scores( diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 226139dbb..4a01868d5 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -210,9 +210,8 @@ class BaseTpWorker(ABC): def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch): forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - logits_output = self.model_runner.forward(forward_batch).logits_output - embeddings = logits_output.embeddings - return embeddings + output = self.model_runner.forward(forward_batch).logits_output + return output # Returns EmbeddingPoolerOutput class TpModelWorker(BaseTpWorker): diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 831b3b6a0..7e32a8356 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -427,6 +427,9 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): # For hidden states before normal return_hidden_states_before_norm: bool = False + # Whether to return pooled hidden states (pre-head transformer output) + return_pooled_hidden_states: bool = False + # For hisparse hisparse_coordinator: Optional[HiSparseCoordinator] = None @@ -483,6 +486,7 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): tbo_split_seq_index=batch.tbo_split_seq_index, dimensions=batch.dimensions, return_hidden_states_before_norm=batch.return_hidden_states_before_norm, + return_pooled_hidden_states=batch.return_pooled_hidden_states, rids=[req.rid for req in batch.reqs], ) device = model_runner.device diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py index 932a15e71..ded66a366 100644 --- a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py @@ -206,6 +206,9 @@ class PiecewiseCudaGraphRunner: self.is_multimodal = model_runner.is_multimodal self.mamba_track_enabled = self.is_mamba_track_enabled() + # Classification/reward forwards branch on return_pooled_hidden_states; piecewise + # CUDA graph capture must use the same flag value as replay for those models. + self.capture_return_pooled_hidden_states = not model_runner.is_generation # Graph inputs with torch.device(self.device): @@ -390,6 +393,7 @@ class PiecewiseCudaGraphRunner: num_token_non_padded_cpu=num_tokens, global_forward_mode=ForwardMode.EXTEND, lora_ids=None, + return_pooled_hidden_states=self.capture_return_pooled_hidden_states, ) # Attention backend @@ -551,6 +555,7 @@ class PiecewiseCudaGraphRunner: num_token_non_padded_cpu=num_tokens, global_forward_mode=ForwardMode.EXTEND, lora_ids=None, + return_pooled_hidden_states=self.capture_return_pooled_hidden_states, ) self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens) @@ -748,6 +753,10 @@ class PiecewiseCudaGraphRunner: top_p_normalized_logprobs=forward_batch.top_p_normalized_logprobs, top_p=forward_batch.top_p, dimensions=forward_batch.dimensions, + return_pooled_hidden_states=( + self.capture_return_pooled_hidden_states + or forward_batch.return_pooled_hidden_states + ), ) if out_cache_loc_swa is not None: diff --git a/python/sglang/srt/models/gemma2_reward.py b/python/sglang/srt/models/gemma2_reward.py index 03bea4d10..8c8eda22b 100644 --- a/python/sglang/srt/models/gemma2_reward.py +++ b/python/sglang/srt/models/gemma2_reward.py @@ -61,7 +61,12 @@ class Gemma2ForSequenceClassification(nn.Module): last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings scores = self.score(last_token_hidden) - return EmbeddingPoolerOutput(scores) + return EmbeddingPoolerOutput( + embeddings=scores, + pooled_hidden_states=( + last_token_hidden if forward_batch.return_pooled_hidden_states else None + ), + ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): Gemma2ForCausalLM.load_weights(self, weights) diff --git a/python/sglang/srt/models/internlm2_reward.py b/python/sglang/srt/models/internlm2_reward.py index 68be8d001..6fb4587fe 100644 --- a/python/sglang/srt/models/internlm2_reward.py +++ b/python/sglang/srt/models/internlm2_reward.py @@ -55,7 +55,12 @@ class InternLM2ForRewardModel(nn.Module): hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings scores = self.v_head(last_token_hidden) - return EmbeddingPoolerOutput(scores) + return EmbeddingPoolerOutput( + embeddings=scores, + pooled_hidden_states=( + last_token_hidden if forward_batch.return_pooled_hidden_states else None + ), + ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): return InternLM2ForCausalLM.load_weights(self, weights) diff --git a/python/sglang/srt/models/llama_reward.py b/python/sglang/srt/models/llama_reward.py index 2f78dfa1b..a8e0f04e6 100644 --- a/python/sglang/srt/models/llama_reward.py +++ b/python/sglang/srt/models/llama_reward.py @@ -18,7 +18,12 @@ import torch from torch import nn from transformers import LlamaConfig -from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType +from sglang.srt.layers.pooler import ( + EmbeddingPoolerOutput, + Pooler, + PoolingType, + pool_hidden_states, +) from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel @@ -61,7 +66,12 @@ class LlamaForSequenceClassification(nn.Module): last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings scores = self.score(last_token_hidden) - return EmbeddingPoolerOutput(scores) + return EmbeddingPoolerOutput( + embeddings=scores, + pooled_hidden_states=( + last_token_hidden if forward_batch.return_pooled_hidden_states else None + ), + ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): return LlamaForCausalLM.load_weights(self, weights) @@ -114,7 +124,17 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific -1, self.num_labels // 2 ) scores = (rews * pooled_weights).sum(dim=-1).view(-1, 1) - return EmbeddingPoolerOutput(scores) + + pooled_hidden = None + if forward_batch.return_pooled_hidden_states: + pooled_hidden = pool_hidden_states( + self.pooler.pooling_type, hidden_states, forward_batch + ) + + return EmbeddingPoolerOutput( + embeddings=scores, + pooled_hidden_states=pooled_hidden, + ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): return super().load_weights(weights) diff --git a/python/sglang/srt/models/qwen2_rm.py b/python/sglang/srt/models/qwen2_rm.py index f5ed9eae2..aedebd178 100644 --- a/python/sglang/srt/models/qwen2_rm.py +++ b/python/sglang/srt/models/qwen2_rm.py @@ -18,7 +18,12 @@ import torch from torch import nn from transformers import Qwen2Config -from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType +from sglang.srt.layers.pooler import ( + EmbeddingPoolerOutput, + Pooler, + PoolingType, + pool_hidden_states, +) from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.qwen2 import Qwen2ForCausalLM, Qwen2Model @@ -63,7 +68,16 @@ class Qwen2ForRewardModel(nn.Module): logits = self.score(hidden_states) pooled_logits = self.pooler(logits, forward_batch).embeddings - return EmbeddingPoolerOutput(pooled_logits) + pooled_hidden = None + if forward_batch.return_pooled_hidden_states: + pooled_hidden = pool_hidden_states( + self.pooler.pooling_type, hidden_states, forward_batch + ) + + return EmbeddingPoolerOutput( + embeddings=pooled_logits, + pooled_hidden_states=pooled_hidden, + ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Filter out lm_head weights of Qwen2ForCausalLM diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 0c289beff..0bc5db5d9 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -800,6 +800,9 @@ class ServerArgs: # Handle piecewise CUDA graph. self._handle_piecewise_cuda_graph() + # Handle multi-item scoring constraints. + self._handle_multi_item_scoring() + # Get GPU memory capacity, which is a common dependency for several configuration steps. gpu_mem = get_device_memory_capacity(self.device) @@ -1205,6 +1208,21 @@ class ServerArgs: if self.debug_cuda_graph: self.disable_piecewise_cuda_graph = True + def _handle_multi_item_scoring(self): + """Disable CUDA graphs when multi-item scoring delimiter is set. + + The padded static input_ids buffer used by CUDA graph replay causes + spurious delimiter matches in score_and_pool's MIS path. + """ + if self.multi_item_scoring_delimiter is None: + return + if not self.disable_cuda_graph: + logger.warning( + "CUDA graph is disabled because --multi-item-scoring-delimiter is set." + ) + self.disable_cuda_graph = True + self.disable_piecewise_cuda_graph = True + def _handle_gpu_memory_settings(self, gpu_mem): """ Configure GPU memory-dependent settings including diff --git a/test/registered/prefill_only/test_pooled_hidden_states.py b/test/registered/prefill_only/test_pooled_hidden_states.py new file mode 100644 index 000000000..66582a773 --- /dev/null +++ b/test/registered/prefill_only/test_pooled_hidden_states.py @@ -0,0 +1,427 @@ +"""Tests for the return_pooled_hidden_states feature on the scoring API. + +Covers both Engine-level (Python API) and HTTP-level (/v1/score) integration: + + TestPooledHiddenStatesEngine — SeqCls model, single-item scoring + TestPooledHiddenStatesMISEngine — SeqCls model, MIS delimiter mode + TestPooledHiddenStatesHTTP — HTTP layer serialization round-trip + TestPooledHiddenStatesCausalLMRejection — CausalLM must reject the flag + +Each test class spins up its own Engine or server so GPU memory is isolated. +""" + +import json +import unittest + +import requests +import torch + +from sglang.srt.entrypoints.engine import Engine +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +register_cuda_ci(est_time=240, suite="stage-b-test-1-gpu-small") + +_SEQCLS_MODEL = "Qwen/Qwen3-0.6B" +_QWEN3_EOT_TOKEN_ID = 151643 +_CAUSAL_LM_MODEL = DEFAULT_SMALL_MODEL_NAME_FOR_TEST +_NUM_LABELS = 4 + +# Local overrides for offline testing (no network). Set to None to use HF hub. +_LOCAL_SEQCLS_MODEL = ( + "/shared/public/elr-models/Qwen/Qwen3-0.6B/e6de91484c29aa9480d55605af694f39b081c455" +) +_LOCAL_CAUSAL_LM_MODEL = "/shared/public/elr-models/meta-llama/Llama-3.2-1B-Instruct/e9f8effbab1cbdc515c11ee6e098e3d5a9f51e14" + +import os + +if _LOCAL_SEQCLS_MODEL and os.path.isdir(_LOCAL_SEQCLS_MODEL): + _SEQCLS_MODEL = _LOCAL_SEQCLS_MODEL +if _LOCAL_CAUSAL_LM_MODEL and os.path.isdir(_LOCAL_CAUSAL_LM_MODEL): + _CAUSAL_LM_MODEL = _LOCAL_CAUSAL_LM_MODEL + + +# --------------------------------------------------------------------------- +# Engine — single-item scoring (no MIS) +# --------------------------------------------------------------------------- + + +class TestPooledHiddenStatesEngine(CustomTestCase): + """Validates return_pooled_hidden_states through the Engine Python API. + + Uses Qwen3ForSequenceClassification with a random head so we only care + about shape and pipeline plumbing, not numerical accuracy. + """ + + @classmethod + def setUpClass(cls): + cls.engine = Engine( + model_path=_SEQCLS_MODEL, + disable_radix_cache=True, + json_model_override_args=json.dumps( + { + "architectures": ["Qwen3ForSequenceClassification"], + "num_labels": _NUM_LABELS, + } + ), + mem_fraction_static=0.15, + ) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "engine") and cls.engine: + cls.engine.shutdown() + torch.cuda.empty_cache() + + def test_phs_returned_when_requested(self): + """Pooled hidden states are present and shaped correctly.""" + result = self.engine.score( + query="Rate each:", + items=["Good", "Bad"], + return_pooled_hidden_states=True, + ) + self.assertIsNotNone(result.pooled_hidden_states) + self.assertEqual(len(result.pooled_hidden_states), 2) + for phs in result.pooled_hidden_states: + self.assertIsInstance(phs, torch.Tensor) + self.assertEqual(phs.dim(), 1) + self.assertGreater(phs.shape[0], 0) + + def test_phs_none_when_not_requested(self): + """Without the flag, pooled_hidden_states must be None.""" + result = self.engine.score( + query="Rate each:", + items=["Good", "Bad"], + return_pooled_hidden_states=False, + ) + self.assertIsNone(result.pooled_hidden_states) + + def test_phs_shape_is_consistent(self): + """PHS tensors for different items share the same hidden dimension.""" + result = self.engine.score( + query="Evaluate:", + items=["Alpha", "Beta", "Gamma"], + return_pooled_hidden_states=True, + ) + self.assertIsNotNone(result.pooled_hidden_states) + dims = {phs.shape[0] for phs in result.pooled_hidden_states} + self.assertEqual(len(dims), 1, "All PHS vectors must share the same hidden dim") + self.assertGreater(dims.pop(), 0) + + def test_phs_count_matches_items(self): + """Number of PHS tensors equals number of items for various batch sizes.""" + for n in [1, 3, 5]: + with self.subTest(n=n): + result = self.engine.score( + query="Classify:", + items=[f"Item {i}" for i in range(n)], + return_pooled_hidden_states=True, + ) + self.assertIsNotNone(result.pooled_hidden_states) + self.assertEqual(len(result.pooled_hidden_states), n) + + def test_phs_on_cpu(self): + """Returned tensors live on CPU (no GPU references leak to caller).""" + result = self.engine.score( + query="Check device:", + items=["Test"], + return_pooled_hidden_states=True, + ) + for phs in result.pooled_hidden_states: + self.assertEqual(str(phs.device), "cpu") + + def test_phs_deterministic(self): + """Identical requests produce identical PHS tensors.""" + kwargs = dict( + query="Evaluate:", items=["A", "B"], return_pooled_hidden_states=True + ) + phs1 = self.engine.score(**kwargs).pooled_hidden_states + phs2 = self.engine.score(**kwargs).pooled_hidden_states + for t1, t2 in zip(phs1, phs2): + self.assertTrue( + torch.allclose(t1, t2, atol=1e-5), + "Pooled hidden states differ across identical requests", + ) + + def test_scores_unaffected_by_phs_flag(self): + """The phs flag must not change the scores themselves (fp16 tolerance).""" + kwargs = dict(query="Rate:", items=["X", "Y", "Z"], apply_softmax=True) + scores_without = self.engine.score( + **kwargs, return_pooled_hidden_states=False + ).scores + scores_with = self.engine.score( + **kwargs, return_pooled_hidden_states=True + ).scores + self.assertEqual(len(scores_without), len(scores_with)) + for row_a, row_b in zip(scores_without, scores_with): + for a, b in zip(row_a, row_b): + self.assertAlmostEqual(a, b, places=2) + + def test_phs_with_tokenized_inputs(self): + """Pre-tokenized inputs also return PHS correctly.""" + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(_SEQCLS_MODEL) + query, items = "Evaluate:", ["Alpha", "Beta"] + result = self.engine.score( + query=tok.encode(query), + items=[tok.encode(i) for i in items], + return_pooled_hidden_states=True, + ) + self.assertIsNotNone(result.pooled_hidden_states) + self.assertEqual(len(result.pooled_hidden_states), 2) + + +# --------------------------------------------------------------------------- +# Engine — MIS delimiter mode +# --------------------------------------------------------------------------- + + +class TestPooledHiddenStatesMISEngine(CustomTestCase): + """Validates return_pooled_hidden_states in MIS (delimiter) scoring mode. + + MIS packs all items into one sequence; the PHS at each delimiter position + should be returned per-item. + """ + + @classmethod + def setUpClass(cls): + cls.engine = Engine( + model_path=_SEQCLS_MODEL, + disable_radix_cache=True, + chunked_prefill_size=-1, + multi_item_scoring_delimiter=_QWEN3_EOT_TOKEN_ID, + json_model_override_args=json.dumps( + { + "architectures": ["Qwen3ForSequenceClassification"], + "num_labels": _NUM_LABELS, + } + ), + mem_fraction_static=0.15, + ) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "engine") and cls.engine: + cls.engine.shutdown() + torch.cuda.empty_cache() + + def test_mis_phs_count_matches_items(self): + """MIS must return one PHS tensor per item.""" + items = ["Option A", "Option B", "Option C"] + result = self.engine.score( + query="Rate each:", items=items, return_pooled_hidden_states=True + ) + self.assertIsNotNone(result.pooled_hidden_states) + self.assertEqual(len(result.pooled_hidden_states), len(items)) + + def test_mis_phs_none_when_not_requested(self): + result = self.engine.score( + query="Rate each:", + items=["A", "B"], + return_pooled_hidden_states=False, + ) + self.assertIsNone(result.pooled_hidden_states) + + def test_mis_phs_are_tensors_on_cpu(self): + result = self.engine.score( + query="Classify:", items=["X", "Y"], return_pooled_hidden_states=True + ) + for phs in result.pooled_hidden_states: + self.assertIsInstance(phs, torch.Tensor) + self.assertEqual(str(phs.device), "cpu") + + def test_mis_phs_different_items_different_hidden_states(self): + """Different items should produce distinct PHS vectors.""" + items = [ + "Option A is about cats", + "Option B is about dogs", + "Option C is about fish", + ] + result = self.engine.score( + query="Classify:", items=items, return_pooled_hidden_states=True + ) + phs = result.pooled_hidden_states + self.assertFalse( + all(torch.allclose(phs[0], p, atol=1e-6) for p in phs[1:]), + "All MIS items returned identical hidden states", + ) + + def test_mis_single_item(self): + """Single item through MIS path still returns one PHS tensor.""" + result = self.engine.score( + query="Evaluate:", items=["Only one"], return_pooled_hidden_states=True + ) + self.assertIsNotNone(result.pooled_hidden_states) + self.assertEqual(len(result.pooled_hidden_states), 1) + + def test_mis_many_items(self): + """10 items all produce PHS tensors of consistent shape.""" + items = [f"Item {i}" for i in range(10)] + result = self.engine.score( + query="Classify:", items=items, return_pooled_hidden_states=True + ) + self.assertIsNotNone(result.pooled_hidden_states) + self.assertEqual(len(result.pooled_hidden_states), len(items)) + shapes = {phs.shape for phs in result.pooled_hidden_states} + self.assertEqual(len(shapes), 1, "MIS PHS shapes should be uniform") + + def test_mis_scores_unaffected_by_phs_flag(self): + """Enabling PHS does not alter the returned scores (fp16 tolerance).""" + kwargs = dict( + query="Rate:", items=["Alpha", "Beta", "Gamma"], apply_softmax=True + ) + scores_without = self.engine.score( + **kwargs, return_pooled_hidden_states=False + ).scores + scores_with = self.engine.score( + **kwargs, return_pooled_hidden_states=True + ).scores + for row_a, row_b in zip(scores_without, scores_with): + for a, b in zip(row_a, row_b): + self.assertAlmostEqual(a, b, places=2) + + +# --------------------------------------------------------------------------- +# CausalLM rejection +# --------------------------------------------------------------------------- + + +class TestPooledHiddenStatesCausalLMRejection(CustomTestCase): + """CausalLM models must reject return_pooled_hidden_states=True.""" + + @classmethod + def setUpClass(cls): + cls.engine = Engine(model_path=_CAUSAL_LM_MODEL) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "engine") and cls.engine: + cls.engine.shutdown() + torch.cuda.empty_cache() + + def test_causal_lm_rejects_phs(self): + """ValueError raised when requesting PHS from a CausalLM.""" + with self.assertRaises(ValueError) as ctx: + self.engine.score( + query="Test", + items=["Item"], + label_token_ids=[1, 2], + return_pooled_hidden_states=True, + ) + self.assertIn("CausalLM", str(ctx.exception)) + + def test_causal_lm_without_phs_still_works(self): + """Baseline: CausalLM scoring without the flag works fine.""" + result = self.engine.score( + query="Test", + items=["Item"], + label_token_ids=[1, 2], + apply_softmax=True, + return_pooled_hidden_states=False, + ) + self.assertEqual(len(result.scores), 1) + self.assertIsNone(result.pooled_hidden_states) + + +# --------------------------------------------------------------------------- +# HTTP layer +# --------------------------------------------------------------------------- + + +class TestPooledHiddenStatesHTTP(CustomTestCase): + """HTTP integration: /v1/score with return_pooled_hidden_states. + + Validates that the Pydantic schema, JSON serialization, and ORJSONResponse + round-trip preserves the pooled hidden states as nested lists. + """ + + @classmethod + def setUpClass(cls): + cls.model = _SEQCLS_MODEL + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--disable-radix-cache", + "--json-model-override-args", + json.dumps( + { + "architectures": ["Qwen3ForSequenceClassification"], + "num_labels": _NUM_LABELS, + } + ), + "--mem-fraction-static", + "0.15", + ], + ) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "process") and cls.process: + kill_process_tree(cls.process.pid) + + def _post(self, payload): + return requests.post(self.base_url + "/v1/score", json=payload) + + def test_phs_in_response_json(self): + """Response includes pooled_hidden_states as nested float lists.""" + resp = self._post( + { + "query": "Rate each:", + "items": ["Good", "Bad"], + "return_pooled_hidden_states": True, + "model": self.model, + } + ) + self.assertEqual(resp.status_code, 200) + body = resp.json() + phs = body.get("pooled_hidden_states") + self.assertIsNotNone(phs) + self.assertEqual(len(phs), 2) + for item_phs in phs: + self.assertIsInstance(item_phs, list) + self.assertGreater(len(item_phs), 0) + for v in item_phs: + self.assertIsInstance(v, float) + + def test_phs_absent_when_not_requested(self): + """Without the flag, pooled_hidden_states is null in JSON.""" + resp = self._post( + { + "query": "Rate each:", + "items": ["Good"], + "model": self.model, + } + ) + self.assertEqual(resp.status_code, 200) + body = resp.json() + self.assertIsNone(body.get("pooled_hidden_states")) + + def test_phs_matches_item_count(self): + """Number of PHS vectors equals number of items.""" + items = ["A", "B", "C", "D"] + resp = self._post( + { + "query": "Classify:", + "items": items, + "return_pooled_hidden_states": True, + "model": self.model, + } + ) + self.assertEqual(resp.status_code, 200) + phs = resp.json()["pooled_hidden_states"] + self.assertEqual(len(phs), len(items)) + + +if __name__ == "__main__": + unittest.main(verbosity=3) diff --git a/test/registered/unit/layers/test_pooler_score_and_pool.py b/test/registered/unit/layers/test_pooler_score_and_pool.py index b40c5afca..2f0c4a1b9 100644 --- a/test/registered/unit/layers/test_pooler_score_and_pool.py +++ b/test/registered/unit/layers/test_pooler_score_and_pool.py @@ -23,13 +23,16 @@ from sglang.test.test_utils import CustomTestCase register_cpu_ci(est_time=9, suite="stage-a-test-cpu") -def _make_forward_batch(extend_seq_lens, is_prefill_only=False): +def _make_forward_batch( + extend_seq_lens, is_prefill_only=False, return_pooled_hidden_states=False +): """Build a minimal ForwardBatch stub for pooler unit tests.""" return SimpleNamespace( extend_seq_lens=torch.tensor(extend_seq_lens, dtype=torch.long), extend_seq_lens_cpu=extend_seq_lens, is_prefill_only=is_prefill_only, dimensions=None, + return_pooled_hidden_states=return_pooled_hidden_states, )