mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-06-30 19:57:52 +00:00
118 lines
4.9 KiB
Python
118 lines
4.9 KiB
Python
from typing import Optional
|
|
|
|
import torch
|
|
|
|
|
|
def fast_topk(values, topk, dim):
|
|
if topk == 1:
|
|
# Use max along the specified dimension to get both value and index
|
|
return torch.max(values, dim=dim, keepdim=True)
|
|
else:
|
|
# Use topk for efficiency with larger k values
|
|
# TODO: implement faster cuda kernels for large vocab sizes
|
|
return torch.topk(values, topk, dim=dim)
|
|
|
|
|
|
def fast_topk_v2(
|
|
score: torch.Tensor,
|
|
lengths: torch.Tensor,
|
|
topk: int,
|
|
row_starts: Optional[torch.Tensor] = None,
|
|
) -> torch.Tensor:
|
|
"""
|
|
Get the topk indices of the score tensor.
|
|
Args:
|
|
score: The score tensor of shape (B, L). The score tensor is the logits
|
|
between the query and the key whose layout is either ragged or paged.
|
|
row_starts is only required when the key is ragged.
|
|
lengths: The lengths tensor of shape (B)
|
|
topk: The number of topk indices to get
|
|
row_starts: The start index of each row in the score tensor of shape (B).
|
|
For each row i, topk only applies to section [row_starts[i], row_starts[i] + lengths[i]]
|
|
of the score tensor.
|
|
Returns:
|
|
The topk indices tensor of shape (B, topk)
|
|
"""
|
|
assert (
|
|
topk == 2048
|
|
), "fast_topk_v2 is only optimized for deepseek v3.2 model, where topk=2048"
|
|
assert score.dim() == 2
|
|
topk_indices = score.new_empty((score.size(0), topk), dtype=torch.int32)
|
|
torch.ops.sgl_kernel.fast_topk(score, topk_indices, lengths, row_starts)
|
|
return topk_indices
|
|
|
|
|
|
def fast_topk_transform_fused(
|
|
score: torch.Tensor,
|
|
lengths: torch.Tensor,
|
|
page_table_size_1: torch.Tensor, # NOTE: page size should be 1
|
|
cu_seqlens_q: torch.Tensor,
|
|
topk: int,
|
|
row_starts: Optional[torch.Tensor] = None,
|
|
) -> torch.Tensor:
|
|
"""
|
|
Get the topk indices of the score tensor and then transform the topk indices
|
|
to indices to the page table (page_size = 1)
|
|
Args:
|
|
score: The score tensor of shape (B, L). The score tensor is the logits
|
|
between the query and the key whose layout is either ragged or paged.
|
|
row_starts is only required when the key is ragged.
|
|
lengths: The lengths tensor of shape (B)
|
|
page_table_size_1: The page table tensor of shape (Batch, topk)
|
|
cu_seqlens_q: The cumulative sequence lengths tensor of shape (Batch + 1)
|
|
topk: The number of topk indices to get
|
|
row_starts: The start index of each row in the score tensor of shape (B).
|
|
For each row i, topk only applies to section [row_starts[i], row_starts[i] + lengths[i]]
|
|
of the score tensor. It's only used for cases where the key is
|
|
ragged, i.e. during extend and draft extend.
|
|
Returns:
|
|
The topk indices tensor of shape (B, topk)
|
|
"""
|
|
assert (
|
|
topk == 2048
|
|
), "fast_topk_transform_fused is only optimized for deepseek v3.2 model, where topk=2048"
|
|
assert score.dim() == 2
|
|
src_page_table = page_table_size_1
|
|
dst_page_table = score.new_empty((score.shape[0], topk), dtype=torch.int32)
|
|
torch.ops.sgl_kernel.fast_topk_transform_fused(
|
|
score, lengths, dst_page_table, src_page_table, cu_seqlens_q, row_starts
|
|
)
|
|
return dst_page_table
|
|
|
|
|
|
def fast_topk_transform_ragged_fused(
|
|
score: torch.Tensor,
|
|
lengths: torch.Tensor,
|
|
topk_indices_offset: torch.Tensor, # ragged kv
|
|
topk: int,
|
|
row_starts: Optional[torch.Tensor] = None,
|
|
) -> torch.Tensor:
|
|
"""
|
|
Get the topk indices of the score tensor and then transform the topk indices to
|
|
indices to ragged kv (non-paged). This function is only used for extend,
|
|
not including draft extend.
|
|
Args:
|
|
score: The score tensor of shape (B, L). The score tensor is the logits
|
|
between the query and the key which can be ragged or paged.
|
|
row_starts is only required when the key is ragged.
|
|
lengths: The lengths tensor of shape (B)
|
|
topk_indices_offset: The offset of topk indices in ragged kv of shape (B)
|
|
topk: The number of topk indices to get
|
|
row_starts: The start index of each row in the score tensor of shape (B).
|
|
For each row i, topk only applies to section [row_starts[i], row_starts[i] + lengths[i]]
|
|
of the score tensor. It can be None if only the fast path is triggered,
|
|
in the case of all values in lengths <= topk (not checked in the kernel,
|
|
guaranteed by the caller).
|
|
Returns:
|
|
The topk indices tensor of shape (B, topk)
|
|
"""
|
|
assert (
|
|
topk == 2048
|
|
), "fast_topk_transform_ragged_fused is only optimized for deepseek v3.2 model, where topk=2048"
|
|
assert score.dim() == 2
|
|
topk_indices_ragged = score.new_empty((score.shape[0], topk), dtype=torch.int32)
|
|
torch.ops.sgl_kernel.fast_topk_transform_ragged_fused(
|
|
score, lengths, topk_indices_ragged, topk_indices_offset, row_starts
|
|
)
|
|
return topk_indices_ragged
|