Files
2026-04-20 16:34:50 +09:00

51 lines
1.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from modules import script_callbacks, extra_networks, prompt_parser, sd_models
from functools import reduce
# Try to import model_hijack from modules.sd_hijack (only present on A1111 / Forge Classic).
# Builds without it (e.g. Forge Neo) get None, which get_token_counter treats as
# "token counting unavailable".
try:
    from modules.sd_hijack import model_hijack
except ImportError:  # also covers ModuleNotFoundError, which subclasses ImportError
    model_hijack = None
def _is_forge():
    """Return True when ``modules_forge.forge_version`` is importable (a Forge build).

    The answer is memoised on the function object. Python does not cache a
    *failed* import, so without memoisation a non-Forge install would repeat
    the ``sys.path`` search on every token-count request.
    """
    try:
        return _is_forge.cached
    except AttributeError:
        pass
    try:
        from modules_forge import forge_version  # noqa: F401  (presence check only)
    except ImportError:
        _is_forge.cached = False
    else:
        _is_forge.cached = True
    return _is_forge.cached


def get_token_counter(text, steps):
    """Count prompt tokens with ``model_hijack.get_prompt_lengths``.

    The prompt has its extra-network tags removed by
    ``extra_networks.parse_prompt``. It is then split and expanded over
    *steps* by ``prompt_parser``. The scheduled prompt with the highest token
    count determines the result.

    Args:
        text: Raw prompt text.
        steps: Step count passed to the prompt-schedule expansion.

    Returns:
        ``{"token_count": ..., "max_length": ...}`` taken from the
        ``(token_count, max_length)`` pair that ``get_prompt_lengths`` reports
        for that prompt. Returns ``{"token_count": 0, "max_length": 0}`` when
        ``model_hijack`` is unavailable (e.g. Forge Neo) or when any
        ``Exception`` occurs. Failures are never propagated to the caller.
    """
    zero = {"token_count": 0, "max_length": 0}
    try:
        # No sd_hijack (e.g. Forge Neo): counting is disabled, so skip parsing entirely.
        if model_hijack is None:
            return zero

        try:
            text, _ = extra_networks.parse_prompt(text)
            _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
            prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(
                prompt_flat_list, steps
            )
        except Exception:
            # Parsing can fail (presumably on a half-typed prompt). Fall back to
            # counting the text as a single prompt rather than reporting an error.
            prompt_schedules = [[[steps, text]]]

        # Each schedule holds [step, prompt] pairs; only the prompt text is counted.
        prompts = [
            prompt_text
            for schedule in prompt_schedules
            for _step, prompt_text in schedule
        ]

        if _is_forge():
            # On Forge, get_prompt_lengths is given the loaded model's
            # cond_stage_model as an extra argument.
            extra_args = (sd_models.model_data.sd_model.cond_stage_model,)
        else:
            extra_args = ()

        # An empty `prompts` makes max() raise ValueError, which falls through to `zero`.
        token_count, max_length = max(
            (model_hijack.get_prompt_lengths(prompt, *extra_args) for prompt in prompts),
            key=lambda lengths: lengths[0],
        )
        return {"token_count": token_count, "max_length": max_length}
    except Exception:
        # Best-effort helper: any failure (model not loaded, API mismatch, ...)
        # degrades to a zero count instead of raising into the caller.
        return zero