mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 06:19:10 +00:00
tokenizer: harden HF chat template compatibility and kwargs passthrough
This commit is contained in:
@@ -626,7 +626,8 @@ class Tokenizer:
|
||||
def hf_chat_template(
|
||||
self,
|
||||
messages: list,
|
||||
add_generation_prompt: bool = True
|
||||
add_generation_prompt: bool = True,
|
||||
**template_kwargs
|
||||
):
|
||||
"""
|
||||
Tokenize with HF tokenizer. Requires `transformers`
|
||||
@@ -641,16 +642,77 @@ class Tokenizer:
|
||||
:param add_generation_prompt:
|
||||
bool, add generation prompt
|
||||
|
||||
:param template_kwargs:
|
||||
Additional kwargs forwarded to `transformers` chat template rendering,
|
||||
e.g. `tools`, `chat_template_kwargs`, `enable_thinking`, etc.
|
||||
|
||||
:return:
|
||||
Topken IDs tensor, shape (1, num_tokens)
|
||||
Token IDs tensor, shape (1, num_tokens)
|
||||
"""
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
if self.hf_tokenizer is None:
|
||||
self.hf_tokenizer = AutoTokenizer.from_pretrained(self.config.directory)
|
||||
ids = self.hf_tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = add_generation_prompt
|
||||
)["input_ids"]
|
||||
|
||||
try:
|
||||
rendered = self.hf_tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = add_generation_prompt,
|
||||
tokenize = True,
|
||||
return_dict = True,
|
||||
**template_kwargs,
|
||||
)
|
||||
except TypeError:
|
||||
# Compatibility for older `transformers` versions without
|
||||
# `return_dict` support on apply_chat_template.
|
||||
rendered = self.hf_tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = add_generation_prompt,
|
||||
tokenize = True,
|
||||
**template_kwargs,
|
||||
)
|
||||
|
||||
# HF tokenizer versions/models vary:
|
||||
# - dict with "input_ids"
|
||||
# - plain token list
|
||||
# - tensor
|
||||
if isinstance(rendered, dict):
|
||||
ids = rendered.get("input_ids")
|
||||
else:
|
||||
ids = rendered
|
||||
if ids is None:
|
||||
raise ValueError("HF chat template returned no input_ids.")
|
||||
|
||||
if isinstance(ids, torch.Tensor):
|
||||
if ids.ndim == 1:
|
||||
ids = ids.unsqueeze(0)
|
||||
return ids.to(dtype = torch.long)
|
||||
|
||||
if len(ids) > 0 and isinstance(ids[0], list):
|
||||
ids = ids[0]
|
||||
ids = torch.tensor(ids, dtype = torch.long).unsqueeze(0)
|
||||
return ids
|
||||
|
||||
def hf_render_chat_template(
|
||||
self,
|
||||
messages: list,
|
||||
add_generation_prompt: bool = True,
|
||||
**template_kwargs
|
||||
) -> str:
|
||||
"""
|
||||
Render HF chat template to text (no tokenization).
|
||||
|
||||
Useful for debugging or external parser layers that operate on text.
|
||||
"""
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
if self.hf_tokenizer is None:
|
||||
self.hf_tokenizer = AutoTokenizer.from_pretrained(self.config.directory)
|
||||
|
||||
rendered = self.hf_tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt = add_generation_prompt,
|
||||
tokenize = False,
|
||||
**template_kwargs,
|
||||
)
|
||||
return rendered
|
||||
|
||||
Reference in New Issue
Block a user