From 284f20263f60d9f83d925cb88b04c5f7ab4727ae Mon Sep 17 00:00:00 2001
From: kingbri
Date: Mon, 5 Feb 2024 00:20:10 -0500
Subject: [PATCH] API: Clean up tokenizing endpoint

Split the get tokens function into separate wrapper encode and decode
functions for overall code cleanliness.

Signed-off-by: kingbri
---
 backends/exllamav2/model.py | 33 ++++++++++++++++-----------------
 main.py                     |  9 +++------
 2 files changed, 19 insertions(+), 23 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 590b47c..356e984 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -446,24 +446,23 @@ class ExllamaV2Container:
         gc.collect()
         torch.cuda.empty_cache()
 
-    def get_tokens(self, text: Optional[str], ids: Optional[List[int]], **kwargs):
-        """Common function for token operations"""
-        if text:
-            # Assume token encoding
-            return self.tokenizer.encode(
-                text,
-                add_bos=unwrap(kwargs.get("add_bos_token"), True),
-                encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
-            )
-        if ids:
-            # Assume token decoding
-            ids = torch.tensor([ids])
-            return self.tokenizer.decode(
-                ids,
-                decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
-            )[0]
+    def encode_tokens(self, text: str, **kwargs):
+        """Wrapper to encode tokens from a text string"""
 
-        return None
+        return self.tokenizer.encode(
+            text,
+            add_bos=unwrap(kwargs.get("add_bos_token"), True),
+            encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
+        )[0].tolist()
+
+    def decode_tokens(self, ids: List[int], **kwargs):
+        """Wrapper to decode tokens from a list of IDs"""
+
+        ids = torch.tensor([ids])
+        return self.tokenizer.decode(
+            ids,
+            decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
+        )[0]
 
     def get_special_tokens(self, add_bos_token: bool, ban_eos_token: bool):
         return {
diff --git a/main.py b/main.py
index eaf045a..71a90da 100644
--- a/main.py
+++ b/main.py
@@ -414,11 +414,8 @@ async def unload_loras():
 )
 async def encode_tokens(data: TokenEncodeRequest):
     """Encodes a string into tokens."""
-    raw_tokens = MODEL_CONTAINER.get_tokens(data.text, None, **data.get_params())
-
-    # Have to use this if check otherwise Torch's tensors error out
-    # with a boolean issue
-    tokens = raw_tokens[0].tolist() if raw_tokens is not None else []
+    raw_tokens = MODEL_CONTAINER.encode_tokens(data.text, **data.get_params())
+    tokens = unwrap(raw_tokens, [])
 
     response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
     return response
@@ -431,7 +428,7 @@ async def encode_tokens(data: TokenEncodeRequest):
 )
 async def decode_tokens(data: TokenDecodeRequest):
     """Decodes tokens into a string."""
-    message = MODEL_CONTAINER.get_tokens(None, data.tokens, **data.get_params())
+    message = MODEL_CONTAINER.decode_tokens(data.tokens, **data.get_params())
 
     response = TokenDecodeResponse(text=unwrap(message, ""))
     return response
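
A minimal usage sketch of the new wrappers, assuming an already-loaded
ExllamaV2Container instance; the `container` variable name is hypothetical,
while the keyword names (add_bos_token, encode_special_tokens,
decode_special_tokens) are the ones read via kwargs.get in the diff:

    # Sketch only: `container` is assumed to be a loaded ExllamaV2Container.
    ids = container.encode_tokens(
        "Hello, world!",
        add_bos_token=True,            # forwarded as tokenizer.encode(add_bos=...)
        encode_special_tokens=True,
    )
    text = container.decode_tokens(ids, decode_special_tokens=True)

    # encode_tokens now returns a plain Python list via .tolist(), so callers
    # no longer need the tensor truthiness workaround removed from main.py.
    assert isinstance(ids, list)
    assert isinstance(text, str)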