mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-06-30 11:48:01 +00:00
12 lines
381 B
Python
12 lines
381 B
Python
import torch
|
|
|
|
|
|
def fast_topk(values: torch.Tensor, topk: int, dim: int):
    """Return the ``topk`` largest entries of ``values`` along ``dim``.

    Args:
        values: Tensor to select from.
        topk: Number of entries to keep along ``dim``.
        dim: Dimension along which the selection is performed.

    Returns:
        A ``(values, indices)`` named tuple. When ``topk == 1`` this is the
        result of ``torch.max(..., keepdim=True)``, so ``dim`` is retained
        with size 1; otherwise it is the result of ``torch.topk``, where
        ``dim`` has size ``topk``.
    """
    if topk == 1:
        # Use max along the specified dimension to get both value and index.
        # keepdim=True keeps `dim` in the output, matching topk's layout.
        return torch.max(values, dim=dim, keepdim=True)

    # Use topk for efficiency with larger k values
    # TODO: implement faster cuda kernels for large vocab sizes
    return torch.topk(values, topk, dim=dim)
|