Fix adjust_num_token_non_padded_for_attn_tp returning CPU tensor (#19051)

This commit is contained in:
Cheng Wan
2026-02-20 07:23:38 -08:00
committed by GitHub
parent 3358ba8945
commit 38ee749dd9

View File

@@ -208,7 +208,7 @@ class CaptureHiddenMode(IntEnum):
def compute_local_num_token_non_padded(
global_num_token_non_padded: torch.Tensor | int,
global_num_token_non_padded: torch.Tensor,
num_tokens_per_dp: int,
) -> torch.Tensor:
"""Compute local non-padded token count for this attention-TP rank.
@@ -220,10 +220,6 @@ def compute_local_num_token_non_padded(
attn_tp_size = get_attention_tp_size()
tokens_per_rank = num_tokens_per_dp // attn_tp_size
# Make sure global_num_token_non_padded is tensor so torch.clamp doesn't break
if isinstance(global_num_token_non_padded, int):
global_num_token_non_padded = torch.tensor(global_num_token_non_padded)
return torch.clamp(
global_num_token_non_padded - tokens_per_rank * attn_tp_rank,
0,
@@ -542,7 +538,7 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin):
num_tokens_per_dp = self.global_num_tokens_cpu[0]
self.num_token_non_padded = compute_local_num_token_non_padded(
global_num_token_non_padded=self.num_token_non_padded_cpu,
global_num_token_non_padded=self.num_token_non_padded,
num_tokens_per_dp=num_tokens_per_dp,
)