Pass cond_stage_model as a get_prompt_lengths argument only when using Forge

Serick committed 2024-02-14 13:42:13 +09:00
parent 22483d8d2b
commit 8a41850bff


@@ -1,4 +1,4 @@
-from modules import script_callbacks, extra_networks, prompt_parser
+from modules import script_callbacks, extra_networks, prompt_parser, sd_models
 from modules.sd_hijack import model_hijack
 from functools import partial, reduce
@@ -16,8 +16,22 @@ def get_token_counter(text, steps):
         # messages related to it in console
         prompt_schedules = [[[steps, text]]]
 
+    try:
+        from modules_forge import forge_version
+        forge = True
+    except:
+        forge = False
     flat_prompts = reduce(lambda list1, list2: list1 + list2, prompt_schedules)
     prompts = [prompt_text for step, prompt_text in flat_prompts]
-    token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts],
-                                  key=lambda args: args[0])
+    if forge:
+        cond_stage_model = sd_models.model_data.sd_model.cond_stage_model
+        token_count, max_length = max([model_hijack.get_prompt_lengths(prompt,cond_stage_model) for prompt in prompts],
+                                      key=lambda args: args[0])
+    else:
+        token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts],
+                                      key=lambda args: args[0])
     return {"token_count": token_count, "max_length": max_length}