Fix streaming logprobs corruption caused by shared mutable list reference (#21030)

This commit is contained in:
Lianmin Zheng
2026-03-21 00:18:48 -07:00
committed by GitHub
parent d089db0563
commit dba6fb3d30
4 changed files with 36 additions and 48 deletions

View File

@@ -650,22 +650,19 @@ class OpenAIServingChat(OpenAIServingBase):
routed_experts[index] = content["meta_info"].get("routed_experts", None)
# Handle logprobs
finish_reason = content["meta_info"].get("finish_reason", None)
choice_logprobs = None
if request.logprobs:
n_prev_token = n_prev_tokens.get(index, 0)
total_output_logprobs = len(
content["meta_info"]["output_token_logprobs"]
)
# When finish_reason is set and all logprobs have been sent,
# any remaining text is just buffered text being flushed by the
# detokenizer (it holds back text at word boundaries). Return None
# for logprobs since no new tokens were generated for this text.
if n_prev_token < total_output_logprobs or finish_reason is None:
total_output_logprobs = content["meta_info"][
"output_token_logprobs_length"
]
if n_prev_token < total_output_logprobs:
choice_logprobs = self._process_streaming_logprobs(
content, n_prev_token
content, n_prev_token, total_output_logprobs
)
n_prev_tokens[index] = total_output_logprobs
finish_reason = content["meta_info"].get("finish_reason", None)
finish_reason_type = finish_reason["type"] if finish_reason else None
# Track finish_reason for each index
@@ -1174,15 +1171,18 @@ class OpenAIServingChat(OpenAIServingBase):
return ToolCallProcessingResult(None, text, finish_reason)
def _process_streaming_logprobs(
self, content: Dict[str, Any], n_prev_token: int
self,
content: Dict[str, Any],
n_prev_token: int,
total_output_logprobs: int,
) -> ChoiceLogprobs:
"""Process logprobs for streaming response"""
logprobs = to_openai_style_logprobs(
output_token_logprobs=content["meta_info"]["output_token_logprobs"][
n_prev_token:
n_prev_token:total_output_logprobs
],
output_top_logprobs=content["meta_info"].get("output_top_logprobs", [])[
n_prev_token:
n_prev_token:total_output_logprobs
],
)

View File

@@ -244,32 +244,22 @@ class OpenAIServingCompletion(OpenAIServingBase):
input_top_logprobs = None
n_prev_token = n_prev_tokens.get(index, 0)
total_output_logprobs = len(
content["meta_info"]["output_token_logprobs"]
)
output_logprobs_slice = content["meta_info"][
"output_token_logprobs"
][n_prev_token:]
finish_reason_for_logprobs = content["meta_info"]["finish_reason"]
# When finish_reason is set and all logprobs have been sent,
# any remaining text is just buffered text being flushed by the
# detokenizer (it holds back text at word boundaries). Return None
# for logprobs since no new tokens were generated for this text.
total_output_logprobs = content["meta_info"][
"output_token_logprobs_length"
]
if (
len(output_logprobs_slice) == 0
and finish_reason_for_logprobs is not None
and input_token_logprobs is None
n_prev_token < total_output_logprobs
or input_token_logprobs is not None
):
logprobs = None
else:
logprobs = to_openai_style_logprobs(
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_token_logprobs=output_logprobs_slice,
output_token_logprobs=content["meta_info"][
"output_token_logprobs"
][n_prev_token:total_output_logprobs],
output_top_logprobs=content["meta_info"].get(
"output_top_logprobs", []
)[n_prev_token:],
)[n_prev_token:total_output_logprobs],
)
n_prev_tokens[index] = total_output_logprobs

View File

@@ -1719,6 +1719,7 @@ class TokenizerManager(TokenizerCommunicatorMixin, TokenizerManagerMultiItemMixi
meta_info["input_token_logprobs"] = state.input_token_logprobs
meta_info["output_token_logprobs"] = state.output_token_logprobs
meta_info["output_token_logprobs_length"] = len(state.output_token_logprobs)
# 2. Handle top logprobs
if top_logprobs_num > 0: