From 38ee749dd90cdab572d393bb856430ef92c981d2 Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Fri, 20 Feb 2026 07:23:38 -0800 Subject: [PATCH] Fix adjust_num_token_non_padded_for_attn_tp returning CPU tensor (#19051) --- python/sglang/srt/model_executor/forward_batch_info.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index a321226e3..234523532 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -208,7 +208,7 @@ class CaptureHiddenMode(IntEnum): def compute_local_num_token_non_padded( - global_num_token_non_padded: torch.Tensor | int, + global_num_token_non_padded: torch.Tensor, num_tokens_per_dp: int, ) -> torch.Tensor: """Compute local non-padded token count for this attention-TP rank. @@ -220,10 +220,6 @@ def compute_local_num_token_non_padded( attn_tp_size = get_attention_tp_size() tokens_per_rank = num_tokens_per_dp // attn_tp_size - # Make sure global_num_token_non_padded is tensor so torch.clamp doesn't break - if isinstance(global_num_token_non_padded, int): - global_num_token_non_padded = torch.tensor(global_num_token_non_padded) - return torch.clamp( global_num_token_non_padded - tokens_per_rank * attn_tp_rank, 0, @@ -542,7 +538,7 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): num_tokens_per_dp = self.global_num_tokens_cpu[0] self.num_token_non_padded = compute_local_num_token_non_padded( - global_num_token_non_padded=self.num_token_non_padded_cpu, + global_num_token_non_padded=self.num_token_non_padded, num_tokens_per_dp=num_tokens_per_dp, )