mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-21 23:44:04 +00:00
60 lines
2.4 KiB
Python
60 lines
2.4 KiB
Python
import torch
|
|
import os
|
|
|
|
class SPieceTokenizer:
    """Wrapper around a SentencePiece processor with optional special-token handling.

    The model may be supplied as a path to a model file, as the serialized model
    as ``bytes``, or as a torch tensor whose raw bytes are the serialized model
    (the form returned by :meth:`serialize_model`).
    """

    @staticmethod
    def from_pretrained(path, **kwargs):
        """Build a tokenizer from ``path``; extra keyword arguments go to ``__init__``."""
        return SPieceTokenizer(path, **kwargs)

    def __init__(self, tokenizer_path, add_bos=False, add_eos=True, special_tokens=None):
        """Load the SentencePiece model.

        Args:
            tokenizer_path: path to the model file, the serialized model as
                ``bytes``, or a torch tensor holding those bytes.
            add_bos: forwarded to the SentencePiece processor; affects the plain
                ``encode`` path in ``__call__``.
            add_eos: same as ``add_bos``, for the end-of-sequence token.
            special_tokens: optional mapping of literal token text to token id.
                Occurrences in the input bypass SentencePiece (see ``__call__``).

        Raises:
            ValueError: if ``tokenizer_path`` is not bytes/tensor and is not an
                existing file.
        """
        self.add_bos = add_bos
        self.add_eos = add_eos
        self.special_tokens = special_tokens
        # Imported here rather than at module level, so importing this module
        # does not itself require sentencepiece.
        import sentencepiece

        if torch.is_tensor(tokenizer_path):
            # .cpu() is a no-op for CPU tensors; .numpy() would raise on any other device.
            tokenizer_path = tokenizer_path.cpu().numpy().tobytes()

        if isinstance(tokenizer_path, bytes):
            self.tokenizer = sentencepiece.SentencePieceProcessor(
                model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos
            )
        else:
            if not os.path.isfile(tokenizer_path):
                raise ValueError(f"invalid tokenizer: {tokenizer_path}")
            self.tokenizer = sentencepiece.SentencePieceProcessor(
                model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos
            )

    def get_vocab(self):
        """Return a dict mapping each piece string to its id."""
        return {self.tokenizer.id_to_piece(i): i for i in range(self.tokenizer.get_piece_size())}

    def __call__(self, string):
        """Tokenize ``string`` and return ``{"input_ids": [...]}``.

        If ``special_tokens`` is set and the text contains any of its keys, the
        text is split around them. Each special token maps directly to its id,
        and the remaining segments are SentencePiece-encoded without BOS/EOS.
        Otherwise the whole string goes through the processor's default
        ``encode``, which uses the ``add_bos``/``add_eos`` settings given at
        construction.
        """
        if self.special_tokens:
            import re

            # Longest tokens first: regex alternation picks the first alternative
            # that matches, so a token that is a prefix of another (e.g. "<|im"
            # vs "<|im_start|>") would otherwise shadow the longer one. Empty
            # keys are dropped because they would match between every character.
            tokens = sorted((tok for tok in self.special_tokens if tok), key=len, reverse=True)
            if tokens:
                splitter = re.compile("(" + "|".join(map(re.escape, tokens)) + ")")
                parts = splitter.split(string)
                if len(parts) > 1:  # at least one special token occurred
                    result = []
                    for part in parts:
                        if not part:
                            continue
                        if part in self.special_tokens:
                            result.append(self.special_tokens[part])
                        else:
                            # NOTE(review): BOS/EOS are not added on this path, unlike
                            # the plain encode below; confirm callers expect that.
                            result.extend(self.tokenizer.encode(part, add_bos=False, add_eos=False))
                    return {"input_ids": result}

        return {"input_ids": self.tokenizer.encode(string)}

    def decode(self, token_ids, skip_special_tokens=False):
        """Decode ids back to text.

        If ``skip_special_tokens`` is true, ids that appear as values in
        ``special_tokens`` are removed before decoding.
        """
        if skip_special_tokens and self.special_tokens:
            special_ids = set(self.special_tokens.values())
            token_ids = [tid for tid in token_ids if tid not in special_ids]
        return self.tokenizer.decode(token_ids)

    def serialize_model(self):
        """Return the serialized model proto as a 1-D ``torch.ByteTensor``.

        The result can be passed back to ``__init__`` as ``tokenizer_path``.
        """
        return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))
|