mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
Api: Add token endpoints
Support for encoding and decoding with various parameters. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
16
model.py
16
model.py
@@ -11,7 +11,7 @@ from exllamav2.generator import(
|
||||
ExLlamaV2StreamingGenerator,
|
||||
ExLlamaV2Sampler
|
||||
)
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
# Headroom (in bytes) kept free on the first GPU during auto-split loading.
auto_split_reserve_bytes = 96 * (1024 ** 2)
@@ -195,6 +195,20 @@ class ModelContainer:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
# Common function for token operations
|
||||
def get_tokens(self, text: Optional[str], ids: Optional[List[int]], **kwargs):
|
||||
if text:
|
||||
# Assume token encoding
|
||||
return self.tokenizer.encode(
|
||||
text, add_bos = kwargs.get("add_bos", True),
|
||||
encode_special_tokens = kwargs.get("encode_special_tokens", True)
|
||||
)
|
||||
if ids:
|
||||
# Assume token decoding
|
||||
ids = torch.tensor([ids])
|
||||
return self.tokenizer.decode(ids, decode_special_tokens = kwargs.get("decode_special_tokens", True))[0]
|
||||
|
||||
|
||||
def generate(self, prompt: str, **kwargs):
|
||||
gen = self.generate_gen(prompt, **kwargs)
|
||||
reponse = "".join(gen)
|
||||
|
||||
Reference in New Issue
Block a user