diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index aefb7704..608f25fc 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -160,13 +160,9 @@ class StableDiffusionModelHijack: self.comments = [] self.extra_generation_params = {} - def get_prompt_lengths(self, text): - if self.clip is None: - return "-", "-" - - _, token_count = self.clip.process_texts([text]) - - return token_count, self.clip.get_target_prompt_token_count(token_count) + def get_prompt_lengths(self, text, cond_stage_model): + _, token_count = cond_stage_model.process_texts([text]) + return token_count, cond_stage_model.get_target_prompt_token_count(token_count) def redo_hijack(self, m): pass diff --git a/modules/ui.py b/modules/ui.py index 177c6872..5744f192 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -167,9 +167,15 @@ def update_token_counter(text, steps, *, is_positive=True): # messages related to it in console prompt_schedules = [[[steps, text]]] + try: + cond_stage_model = sd_models.model_data.sd_model.cond_stage_model + assert cond_stage_model is not None + except Exception: + return f"?/?" + flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) prompts = [prompt_text for step, prompt_text in flat_prompts] - token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) + token_count, max_length = max([model_hijack.get_prompt_lengths(prompt, cond_stage_model) for prompt in prompts], key=lambda args: args[0]) return f"{token_count}/{max_length}"