tokenizer: harden HF chat template compatibility and kwargs passthrough

This commit is contained in:
lesj0610
2026-02-12 01:25:30 +09:00
parent 701afb9294
commit 019d965eb6

View File

@@ -626,7 +626,8 @@ class Tokenizer:
def hf_chat_template(
self,
messages: list,
add_generation_prompt: bool = True
add_generation_prompt: bool = True,
**template_kwargs
):
"""
Tokenize with HF tokenizer. Requires `transformers`
@@ -641,16 +642,77 @@ class Tokenizer:
:param add_generation_prompt:
bool, add generation prompt
:param template_kwargs:
Additional kwargs forwarded to `transformers` chat template rendering,
e.g. `tools`, `chat_template_kwargs` or `enable_thinking`.
:return:
Topken IDs tensor, shape (1, num_tokens)
Token IDs tensor, shape (1, num_tokens)
"""
from transformers import AutoTokenizer
if self.hf_tokenizer is None:
self.hf_tokenizer = AutoTokenizer.from_pretrained(self.config.directory)
ids = self.hf_tokenizer.apply_chat_template(
messages,
add_generation_prompt = add_generation_prompt
)["input_ids"]
try:
rendered = self.hf_tokenizer.apply_chat_template(
messages,
add_generation_prompt = add_generation_prompt,
tokenize = True,
return_dict = True,
**template_kwargs,
)
except TypeError:
# Compatibility for older `transformers` versions without
# `return_dict` support on apply_chat_template.
rendered = self.hf_tokenizer.apply_chat_template(
messages,
add_generation_prompt = add_generation_prompt,
tokenize = True,
**template_kwargs,
)
# HF tokenizer versions/models vary:
# - dict with "input_ids"
# - plain token list
# - tensor
if isinstance(rendered, dict):
ids = rendered.get("input_ids")
else:
ids = rendered
if ids is None:
raise ValueError("HF chat template returned no input_ids.")
if isinstance(ids, torch.Tensor):
if ids.ndim == 1:
ids = ids.unsqueeze(0)
return ids.to(dtype = torch.long)
if len(ids) > 0 and isinstance(ids[0], list):
ids = ids[0]
ids = torch.tensor(ids, dtype = torch.long).unsqueeze(0)
return ids
def hf_render_chat_template(
    self,
    messages: list,
    add_generation_prompt: bool = True,
    **template_kwargs
) -> str:
    """
    Render HF chat template to text (no tokenization). Requires `transformers`
    the first time it is called, while the HF tokenizer is not loaded yet.

    Useful for debugging or external parser layers that operate on text.

    :param messages:
        Conversation, passed unchanged to the HF tokenizer's `apply_chat_template`

    :param add_generation_prompt:
        bool, add generation prompt

    :param template_kwargs:
        Additional kwargs forwarded to `apply_chat_template`

    :return:
        The value returned by `apply_chat_template` with `tokenize = False`,
        i.e. the rendered prompt text
    """
    if self.hf_tokenizer is None:
        # Import lazily, and only when the tokenizer really has to be created:
        # `transformers` is an optional dependency and is not needed once an
        # HF tokenizer instance is already cached on this object.
        from transformers import AutoTokenizer
        self.hf_tokenizer = AutoTokenizer.from_pretrained(self.config.directory)

    return self.hf_tokenizer.apply_chat_template(
        messages,
        add_generation_prompt = add_generation_prompt,
        tokenize = False,
        **template_kwargs,
    )