diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index 76360df..a9781ec 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -18,6 +18,7 @@ from common.multimodal import MultimodalEmbeddingWrapper from common.sampling import BaseSamplerRequest from common.templating import PromptTemplate from common.transformers_utils import GenerationConfig +from common.utils import unwrap from endpoints.core.types.model import ModelCard from exllamav3 import Config, Model, Cache, Tokenizer @@ -175,7 +176,11 @@ class ExllamaV3Container(BaseModelContainer): A list of integer token IDs. """ - pass + return self.tokenizer.encode( + text, + add_bos=unwrap(kwargs.get("add_bos_token"), True), + encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True), + ).flatten().tolist() def decode_tokens(self, ids: List[int], **kwargs) -> str: """ @@ -189,9 +194,15 @@ class ExllamaV3Container(BaseModelContainer): The decoded text string. """ - pass + ids = torch.tensor([ids]) + return self.tokenizer.decode( + ids, + decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True), + )[0] - def get_special_tokens(self, **kwargs) -> Dict[str, Any]: + def get_special_tokens( + self, add_bos_token: bool = True, ban_eos_token: bool = False + ): """ Gets special tokens used by the model/tokenizer. @@ -203,7 +214,12 @@ class ExllamaV3Container(BaseModelContainer): to their string or ID representation. """ - pass + return { + "bos_token": self.tokenizer.bos_token if add_bos_token else "", + "eos_token": self.tokenizer.eos_token if not ban_eos_token else "", + "pad_token": self.tokenizer.pad_token, + "unk_token": self.tokenizer.unk_token, + } def model_info(self) -> ModelCard: """