diff --git a/scripts/physton_prompt/get_token_counter.py b/scripts/physton_prompt/get_token_counter.py
index a51ef69..0080284 100644
--- a/scripts/physton_prompt/get_token_counter.py
+++ b/scripts/physton_prompt/get_token_counter.py
@@ -1,4 +1,4 @@
-from modules import script_callbacks, extra_networks, prompt_parser
+from modules import script_callbacks, extra_networks, prompt_parser, sd_models
 from modules.sd_hijack import model_hijack
 from functools import partial, reduce
 
@@ -16,8 +16,23 @@ def get_token_counter(text, steps):
     # messages related to it in console
     prompt_schedules = [[[steps, text]]]
 
+    # Detect the Forge fork: its get_prompt_lengths() expects the conditioning
+    # model as an explicit second argument, so we must branch on it below.
+    try:
+        from modules_forge import forge_version  # presence check only
+        forge = True
+    except ImportError:
+        forge = False
+
     flat_prompts = reduce(lambda list1, list2: list1 + list2, prompt_schedules)
     prompts = [prompt_text for step, prompt_text in flat_prompts]
-    token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts],
-                                  key=lambda args: args[0])
+
+    if forge:
+        cond_stage_model = sd_models.model_data.sd_model.cond_stage_model
+        token_count, max_length = max([model_hijack.get_prompt_lengths(prompt, cond_stage_model) for prompt in prompts],
+                                      key=lambda args: args[0])
+    else:
+        token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts],
+                                      key=lambda args: args[0])
+
     return {"token_count": token_count, "max_length": max_length}