mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-07-03 13:57:04 +00:00
Fix adjust_num_token_non_padded_for_attn_tp returning CPU tensor (#19051)
This commit is contained in:
@@ -208,7 +208,7 @@ class CaptureHiddenMode(IntEnum):
|
||||
|
||||
|
||||
def compute_local_num_token_non_padded(
|
||||
global_num_token_non_padded: torch.Tensor | int,
|
||||
global_num_token_non_padded: torch.Tensor,
|
||||
num_tokens_per_dp: int,
|
||||
) -> torch.Tensor:
|
||||
"""Compute local non-padded token count for this attention-TP rank.
|
||||
@@ -220,10 +220,6 @@ def compute_local_num_token_non_padded(
|
||||
attn_tp_size = get_attention_tp_size()
|
||||
tokens_per_rank = num_tokens_per_dp // attn_tp_size
|
||||
|
||||
# Make sure global_num_token_non_padded is tensor so torch.clamp doesn't break
|
||||
if isinstance(global_num_token_non_padded, int):
|
||||
global_num_token_non_padded = torch.tensor(global_num_token_non_padded)
|
||||
|
||||
return torch.clamp(
|
||||
global_num_token_non_padded - tokens_per_rank * attn_tp_rank,
|
||||
0,
|
||||
@@ -542,7 +538,7 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin):
|
||||
num_tokens_per_dp = self.global_num_tokens_cpu[0]
|
||||
|
||||
self.num_token_non_padded = compute_local_num_token_non_padded(
|
||||
global_num_token_non_padded=self.num_token_non_padded_cpu,
|
||||
global_num_token_non_padded=self.num_token_non_padded,
|
||||
num_tokens_per_dp=num_tokens_per_dp,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user