mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-28 10:11:39 +00:00
Exl3: Add token encode, decode, and special token fetch
Base class methods

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
@@ -18,6 +18,7 @@ from common.multimodal import MultimodalEmbeddingWrapper
|
|||||||
from common.sampling import BaseSamplerRequest
|
from common.sampling import BaseSamplerRequest
|
||||||
from common.templating import PromptTemplate
|
from common.templating import PromptTemplate
|
||||||
from common.transformers_utils import GenerationConfig
|
from common.transformers_utils import GenerationConfig
|
||||||
|
from common.utils import unwrap
|
||||||
from endpoints.core.types.model import ModelCard
|
from endpoints.core.types.model import ModelCard
|
||||||
|
|
||||||
from exllamav3 import Config, Model, Cache, Tokenizer
|
from exllamav3 import Config, Model, Cache, Tokenizer
|
||||||
@@ -175,7 +176,11 @@ class ExllamaV3Container(BaseModelContainer):
|
|||||||
A list of integer token IDs.
|
A list of integer token IDs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
return self.tokenizer.encode(
|
||||||
|
text,
|
||||||
|
add_bos=unwrap(kwargs.get("add_bos_token"), True),
|
||||||
|
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
|
||||||
|
).flatten().tolist()
|
||||||
|
|
||||||
def decode_tokens(self, ids: List[int], **kwargs) -> str:
|
def decode_tokens(self, ids: List[int], **kwargs) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -189,9 +194,15 @@ class ExllamaV3Container(BaseModelContainer):
|
|||||||
The decoded text string.
|
The decoded text string.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
ids = torch.tensor([ids])
|
||||||
|
return self.tokenizer.decode(
|
||||||
|
ids,
|
||||||
|
decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
|
||||||
|
)[0]
|
||||||
|
|
||||||
def get_special_tokens(self, **kwargs) -> Dict[str, Any]:
|
def get_special_tokens(
|
||||||
|
self, add_bos_token: bool = True, ban_eos_token: bool = False
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Gets special tokens used by the model/tokenizer.
|
Gets special tokens used by the model/tokenizer.
|
||||||
|
|
||||||
@@ -203,7 +214,12 @@ class ExllamaV3Container(BaseModelContainer):
|
|||||||
to their string or ID representation.
|
to their string or ID representation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
return {
|
||||||
|
"bos_token": self.tokenizer.bos_token if add_bos_token else "",
|
||||||
|
"eos_token": self.tokenizer.eos_token if not ban_eos_token else "",
|
||||||
|
"pad_token": self.tokenizer.pad_token,
|
||||||
|
"unk_token": self.tokenizer.unk_token,
|
||||||
|
}
|
||||||
|
|
||||||
def model_info(self) -> ModelCard:
|
def model_info(self) -> ModelCard:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user