Mirror of https://github.com/Physton/sd-webui-prompt-all-in-one.git (synced 2026-01-26 11:19:55 +00:00)
Merge pull request #301 from serick4126/forge_support
Add cond_stage_model to get_prompt_lengths argument only when using forge
@@ -1,4 +1,4 @@
-from modules import script_callbacks, extra_networks, prompt_parser
+from modules import script_callbacks, extra_networks, prompt_parser, sd_models
 from modules.sd_hijack import model_hijack
 from functools import partial, reduce
 
@@ -16,8 +16,22 @@ def get_token_counter(text, steps):
     # messages related to it in console
     prompt_schedules = [[[steps, text]]]
 
+    try:
+        from modules_forge import forge_version
+        forge = True
+
+    except:
+        forge = False
+
     flat_prompts = reduce(lambda list1, list2: list1 + list2, prompt_schedules)
     prompts = [prompt_text for step, prompt_text in flat_prompts]
-    token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts],
-                                  key=lambda args: args[0])
+
+    if forge:
+        cond_stage_model = sd_models.model_data.sd_model.cond_stage_model
+        token_count, max_length = max([model_hijack.get_prompt_lengths(prompt,cond_stage_model) for prompt in prompts],
+                                      key=lambda args: args[0])
+    else:
+        token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts],
+                                      key=lambda args: args[0])
 
     return {"token_count": token_count, "max_length": max_length}
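For context, the Forge check in this diff is ordinary import-probing feature detection: the extension decides which get_prompt_lengths signature to use by testing whether the Forge-only package modules_forge can be imported. A minimal standalone sketch of the same pattern follows; some_optional_backend is a hypothetical stand-in for modules_forge, and the print statements stand in for the two call paths.

def backend_available(module_name):
    # Probe for an optional dependency, mirroring the try/except import
    # in the diff above. Any importable module name works here.
    try:
        __import__(module_name)
        return True
    except ImportError:
        return False

# Hypothetical stand-in for the Forge-only module modules_forge.
if backend_available("some_optional_backend"):
    print("Forge-style path: pass cond_stage_model to get_prompt_lengths")
else:
    print("Stock path: call get_prompt_lengths with the prompt only")

One design note: the diff uses a bare except:, which also swallows errors unrelated to the missing module; except ImportError: is the narrower idiom and is what the sketch uses.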