Respect special tokens in WS server lefttrim_token()

This commit is contained in:
turboderp
2024-06-24 02:59:49 +02:00
parent 6a8172cfce
commit ef455a7bb9

View File

@@ -87,11 +87,11 @@ def lefttrim_token(request, ws, server, response):
text = request["text"]
length = int(request["trimmed_length"])
ids = server.tokenizer.cached_encode_str(text)
ids = server.tokenizer.cached_encode_str(text, encode_special_tokens = True)
if ids.shape[-1] <= length:
response["trimmed_text"] = text
else:
response["trimmed_text"] = server.tokenizer.decode(ids[:, -length:])[0]
response["trimmed_text"] = server.tokenizer.decode(ids[:, -length:], decode_special_token = True)[0]
async def infer(request, ws, server, response):