Fix indentation, refactor a bit

This commit is contained in:
turboderp
2024-02-01 06:02:18 +01:00
parent 56b0a1261f
commit 9437f3b3e0

View File

@@ -319,8 +319,16 @@ class ExLlamaV2Tokenizer:
# Decode IDs, or a list of IDs
def decode(self, ids, decode_special_tokens = False):
if type(ids) == torch.Tensor:
if isinstance(ids, list):
texts = []
for i in ids:
texts.append(self.decode(i, decode_special_tokens))
return texts
assert isinstance(ids, torch.Tensor), "ids must be Tensor"
if ids.dim() > 1:
texts = []
@@ -334,11 +342,6 @@ class ExLlamaV2Tokenizer:
ids = ids.tolist()
text = self.decode_(ids, decode_special_tokens)
return text
elif type(ids) == list:
texts = []
for id in ids:
texts.append(self.decode(id, decode_special_tokens))
return texts
# Create padding mask