diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index aefb7704..608f25fc 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -160,13 +160,9 @@ class StableDiffusionModelHijack:
         self.comments = []
         self.extra_generation_params = {}
 
-    def get_prompt_lengths(self, text):
-        if self.clip is None:
-            return "-", "-"
-
-        _, token_count = self.clip.process_texts([text])
-
-        return token_count, self.clip.get_target_prompt_token_count(token_count)
+    def get_prompt_lengths(self, text, cond_stage_model):
+        _, token_count = cond_stage_model.process_texts([text])
+        return token_count, cond_stage_model.get_target_prompt_token_count(token_count)
 
     def redo_hijack(self, m):
         pass
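
With this change, the hijack no longer reads the conditioning model from a
cached self.clip attribute (and the "-", "-" fallback for a missing model
moves to the caller); callers now pass cond_stage_model in explicitly. A
minimal sketch of the new call contract, using a hypothetical stub in place
of the real conditioning model that a loaded checkpoint provides:

    # StubCondStageModel is a hypothetical stand-in exposing the two methods
    # that get_prompt_lengths() relies on.
    class StubCondStageModel:
        def process_texts(self, texts):
            # The webui implementation returns (batch_chunks, token_count);
            # this stub fakes the count as a whitespace word count.
            return None, max(len(t.split()) for t in texts)

        def get_target_prompt_token_count(self, token_count):
            # Assumption: round up to CLIP's 75-token chunk size, mirroring
            # the webui's prompt chunking.
            return ((max(token_count, 1) + 74) // 75) * 75

    def get_prompt_lengths(text, cond_stage_model):
        _, token_count = cond_stage_model.process_texts([text])
        return token_count, cond_stage_model.get_target_prompt_token_count(token_count)

    print(get_prompt_lengths("a photo of a cat", StubCondStageModel()))  # (5, 75)
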
diff --git a/modules/ui.py b/modules/ui.py
index 177c6872..5744f192 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -167,9 +167,15 @@ def update_token_counter(text, steps, *, is_positive=True):
         # messages related to it in console
         prompt_schedules = [[[steps, text]]]
 
+    try:
+        cond_stage_model = sd_models.model_data.sd_model.cond_stage_model
+        assert cond_stage_model is not None
+    except Exception:
+        return "?/?"
+
     flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
     prompts = [prompt_text for step, prompt_text in flat_prompts]
-    token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
+    token_count, max_length = max([model_hijack.get_prompt_lengths(prompt, cond_stage_model) for prompt in prompts], key=lambda args: args[0])
     return f"{token_count}/{max_length}"