mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-07-01 12:17:09 +00:00
42 lines
1.2 KiB
Python
42 lines
1.2 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
import torch
|
|
|
|
from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args
|
|
|
|
if TYPE_CHECKING:
|
|
from tvm_ffi.module import Module
|
|
|
|
|
|
@cache_once
def _jit_resolve_future_token_ids_module(dtype: torch.dtype) -> Module:
    """Build and cache the JIT module specialised for ``dtype``.

    ``make_cpp_args`` turns the dtype into the C++ template arguments. Those
    arguments are forwarded to ``load_jit`` and are also interpolated into
    the ``ResolveFutureTokenIds<...>::run`` wrapper that is exported under
    the name ``resolve_future_token_ids``.

    Args:
        dtype: Element type the kernel should be instantiated for.

    Returns:
        The module object produced by ``load_jit``.
    """
    exported_name = "resolve_future_token_ids"
    kernel_source = "elementwise/resolve_future_token_ids.cuh"
    cpp_args = make_cpp_args(dtype)
    # The wrapper tuple is built inside the call so the template arguments
    # are unpacked before they are formatted into the wrapper name.
    return load_jit(
        exported_name,
        *cpp_args,
        cuda_files=[kernel_source],
        cuda_wrappers=[
            (exported_name, f"ResolveFutureTokenIds<{cpp_args}>::run"),
        ],
    )
|
|
|
|
|
|
def resolve_future_token_ids_cuda(
    input_ids: torch.Tensor, future_token_ids_map: torch.Tensor
) -> None:
    """Substitute future-token placeholders in ``input_ids`` in-place on CUDA.

    Every negative entry ``v`` of ``input_ids`` is overwritten with
    ``future_token_ids_map[-v]``; non-negative entries are left unchanged.

    The JIT module is chosen from ``input_ids.dtype``.
    Supported dtypes: torch.int32, torch.int64.

    Args:
        input_ids: Tensor that is updated in place.
        future_token_ids_map: Lookup tensor indexed by the negated
            placeholder value.
    """
    kernel_module = _jit_resolve_future_token_ids_module(input_ids.dtype)
    kernel_module.resolve_future_token_ids(input_ids, future_token_ids_map)
|