mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-07-05 23:07:00 +00:00
179 lines
6.0 KiB
Python
179 lines
6.0 KiB
Python
import random
|
|
import unittest
|
|
from contextlib import contextmanager
|
|
|
|
from transformers import AutoTokenizer
|
|
|
|
from sglang.srt.utils.patch_tokenizer import (
|
|
_SpecialTokensCachePatcher,
|
|
unpatch_tokenizer,
|
|
)
|
|
from sglang.test.ci.ci_register import register_cpu_ci
|
|
|
|
# Presumably registers this file with sglang's CPU CI (default suite, nightly
# runs); est_time looks like an estimated runtime in seconds — confirm.
register_cpu_ci(suite="default", nightly=True, est_time=30)
|
|
|
|
|
|
class TestPatchTokenizerEndToEndTest(unittest.TestCase):
    """Check that the special-tokens cache patch does not change tokenizer output."""

    def test_patched_produces_same_results_as_raw(self):
        tokenizer = _load_tokenizer()
        test_texts = self._generate_test_texts(tokenizer)
        raw_results = self._run_tokenizer_ops(tokenizer, test_texts)

        _SpecialTokensCachePatcher.patch(tokenizer)
        try:
            patched_results = self._run_tokenizer_ops(tokenizer, test_texts)
        finally:
            # The patch marks the tokenizer's *class*, so it must be undone even
            # when the patched run raises; otherwise it leaks into later tests.
            unpatch_tokenizer(tokenizer)

        self.assertEqual(raw_results.keys(), patched_results.keys())
        for op_name, raw_value in raw_results.items():
            # One subTest per operation so a mismatch names the failing op
            # instead of dumping a diff of the entire results dict.
            with self.subTest(op=op_name):
                self.assertEqual(raw_value, patched_results[op_name])

    @classmethod
    def _generate_test_texts(cls, tokenizer, seed=0):
        """Build a mix of fixed, special-token-bearing, and random texts.

        Args:
            tokenizer: Tokenizer whose special tokens and vocabulary are used.
            seed: Seed for the randomly generated texts. It is fixed by default
                so that a failing run can be reproduced exactly.

        Returns:
            A list of strings to feed through the tokenizer.
        """
        rng = random.Random(seed)
        special_tokens = tokenizer.all_special_tokens
        return [
            "Hello, world!",
            "This is a longer sentence with multiple words.",
            "Numbers 12345 and symbols !@#$%",
            " leading and trailing spaces ",
            "\n\nMultiple\n\nNewlines\n\n",
            *[f"Text with {tok} inside" for tok in special_tokens],
            " ".join(special_tokens),
            *[
                cls._random_text_from_tokens(tokenizer, num_tokens=100, rng=rng)
                for _ in range(5)
            ],
            *[
                cls._random_text_from_tokens(tokenizer, num_tokens=1000, rng=rng)
                for _ in range(3)
            ],
        ]

    @classmethod
    def _random_text_from_tokens(cls, tokenizer, num_tokens, rng=None):
        """Decode ``num_tokens`` ids drawn uniformly from ``[0, vocab_size)``.

        Args:
            tokenizer: Tokenizer used for ``vocab_size`` and ``decode``.
            num_tokens: Number of random token ids to draw.
            rng: Object with a ``randint`` method (e.g. ``random.Random``).
                Defaults to the module-level ``random`` functions.

        Returns:
            The decoded string.
        """
        if rng is None:
            rng = random
        token_ids = [
            rng.randint(0, tokenizer.vocab_size - 1) for _ in range(num_tokens)
        ]
        return tokenizer.decode(token_ids)

    @classmethod
    def _run_tokenizer_ops(cls, tokenizer, texts):
        """Run encode/decode/special-token operations and collect the outputs.

        Returns:
            Dict mapping operation name to its result. It covers single and
            batch encode, single and batch decode (skipping special tokens),
            and the special token strings and ids.
        """
        encode_results = [tokenizer.encode(t) for t in texts]
        batch_encode_results = tokenizer(texts)["input_ids"]
        return {
            "encode": encode_results,
            "batch_encode": batch_encode_results,
            "decode": [
                tokenizer.decode(ids, skip_special_tokens=True)
                for ids in encode_results
            ],
            "batch_decode": tokenizer.batch_decode(
                encode_results, skip_special_tokens=True
            ),
            "special_tokens": tokenizer.all_special_tokens,
            "special_ids": tokenizer.all_special_ids,
        }
|
|
|
|
|
|
class TestPatchTokenizerUnitTest(unittest.TestCase):
    """Unit tests for patching and unpatching via _SpecialTokensCachePatcher."""

    @staticmethod
    def _ensure_unpatched(tokenizer):
        """Undo the patch if the tokenizer's class is still flagged as patched.

        This is registered with ``addCleanup``. A failing assertion mid-test
        then cannot leave the class patched for later tests, and on the normal
        path (where the test already unpatched) nothing is called twice.
        """
        if getattr(type(tokenizer), "_sglang_special_tokens_patched", False):
            unpatch_tokenizer(tokenizer)

    def test_patch_unpatch_restores_original(self):
        tokenizer = _load_tokenizer()
        tokenizer_cls = type(tokenizer)

        original_ids = _get_class_attr_ids(tokenizer_cls)

        _SpecialTokensCachePatcher.patch(tokenizer)
        self.addCleanup(self._ensure_unpatched, tokenizer)
        self.assertTrue(
            getattr(tokenizer_cls, "_sglang_special_tokens_patched", False)
        )

        patched_ids = _get_class_attr_ids(tokenizer_cls)
        changed_attrs = [
            name
            for name in original_ids
            if name in patched_ids and patched_ids[name] != original_ids[name]
        ]
        self.assertGreater(len(changed_attrs), 0, "Patch should change some attributes")

        unpatch_tokenizer(tokenizer)
        self.assertFalse(
            getattr(tokenizer_cls, "_sglang_special_tokens_patched", False)
        )

        restored_ids = _get_class_attr_ids(tokenizer_cls)
        for name in original_ids:
            # Skip bookkeeping attributes in case an earlier run left some behind.
            if name.startswith("_sglang") or name.startswith("_original"):
                continue
            self.assertEqual(
                restored_ids.get(name),
                original_ids[name],
                f"Attribute {name} should be restored to original",
            )

    def test_patch_caches_special_tokens(self):
        with _patched_tokenizer() as tokenizer:
            tokens1 = tokenizer.all_special_tokens
            ids1 = tokenizer.all_special_ids
            tokens2 = tokenizer.all_special_tokens
            ids2 = tokenizer.all_special_ids

            # Identity (not just equality) shows the values come from a cache.
            self.assertIs(tokens1, tokens2)
            self.assertIs(ids1, ids2)

    def test_patch_blocks_add_special_tokens(self):
        with _patched_tokenizer() as tokenizer:
            with self.assertRaises(AssertionError) as ctx:
                tokenizer.add_special_tokens({"pad_token": "<pad>"})
            self.assertIn(
                "Cannot modify special tokens after patch", str(ctx.exception)
            )

    def test_patch_blocks_add_tokens_with_special_flag(self):
        with _patched_tokenizer() as tokenizer:
            with self.assertRaises(AssertionError) as ctx:
                tokenizer.add_tokens(["<new>"], special_tokens=True)
            self.assertIn("Cannot add special tokens after patch", str(ctx.exception))

            # Adding non-special tokens must still be allowed (must not raise).
            tokenizer.add_tokens(["<regular>"], special_tokens=False)

    def test_unpatch_clears_cache(self):
        with _patched_tokenizer() as tokenizer:
            _ = tokenizer.all_special_tokens
            _ = tokenizer.all_special_ids
            self.assertTrue(hasattr(tokenizer, "_sglang_cached_special_tokens"))
            self.assertTrue(hasattr(tokenizer, "_sglang_cached_special_ids"))

        # Leaving the context manager unpatched the tokenizer.
        self.assertFalse(hasattr(tokenizer, "_sglang_cached_special_tokens"))
        self.assertFalse(hasattr(tokenizer, "_sglang_cached_special_ids"))

    def test_double_patch_is_idempotent(self):
        tokenizer = _load_tokenizer()
        tokenizer_cls = type(tokenizer)

        _SpecialTokensCachePatcher.patch(tokenizer)
        self.addCleanup(self._ensure_unpatched, tokenizer)
        ids_after_first_patch = _get_class_attr_ids(tokenizer_cls)

        _SpecialTokensCachePatcher.patch(tokenizer)
        self.assertTrue(
            getattr(tokenizer_cls, "_sglang_special_tokens_patched", False)
        )
        # A second patch must not re-wrap already-patched class attributes.
        self.assertEqual(_get_class_attr_ids(tokenizer_cls), ids_after_first_patch)

        # A single unpatch is enough to fully undo a double patch.
        unpatch_tokenizer(tokenizer)
        self.assertFalse(
            getattr(tokenizer_cls, "_sglang_special_tokens_patched", False)
        )
|
|
|
|
|
|
def _get_class_attr_ids(cls):
    """Snapshot the identity of every attribute defined directly on ``cls``.

    For a ``property``, the identity of its getter (``fget``) is recorded
    instead of the property object itself.

    Returns:
        Dict mapping attribute name to ``id()`` of the attribute (or getter).
    """
    snapshot = {}
    for attr_name, attr_value in vars(cls).items():
        if isinstance(attr_value, property):
            attr_value = attr_value.fget
        snapshot[attr_name] = id(attr_value)
    return snapshot
|
|
|
|
|
|
def _load_tokenizer():
    """Load a fresh copy of the tokenizer these tests exercise.

    Kimi is used because the slowness targeted here (presumably the
    special-token lookups this patch caches) is mainly observed with its
    tokenizer. ``trust_remote_code=True`` is passed; presumably the checkpoint
    ships custom tokenizer code — confirm.
    """
    model_id = "nvidia/Kimi-K2-Thinking-NVFP4"
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    return tokenizer
|
|
|
|
|
|
@contextmanager
def _patched_tokenizer():
    """Yield a freshly loaded tokenizer with the special-tokens cache patch applied.

    The patch is removed with ``unpatch_tokenizer`` when the ``with`` block
    exits, whether it exits normally or by raising.
    """
    patched = _load_tokenizer()
    _SpecialTokensCachePatcher.patch(patched)
    try:
        yield patched
    finally:
        unpatch_tokenizer(patched)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the test cases defined in this module when it is executed as a script.
    unittest.main()
|