This commit is contained in:
lllyasviel
2024-01-25 08:47:34 -08:00
parent 08c0fec7ce
commit 82254870fd
2 changed files with 10 additions and 8 deletions

View File

@@ -160,13 +160,9 @@ class StableDiffusionModelHijack:
self.comments = []
self.extra_generation_params = {}
def get_prompt_lengths(self, text):
    """Return (token_count, target_token_count) for *text*.

    Falls back to the placeholder pair ("-", "-") when no CLIP wrapper is
    attached to this hijack instance (``self.clip is None``).
    """
    clip = self.clip
    if clip is not None:
        # process_texts returns a pair; only the token count is needed here.
        _, n_tokens = clip.process_texts([text])
        return n_tokens, clip.get_target_prompt_token_count(n_tokens)
    return "-", "-"
def get_prompt_lengths(self, text, cond_stage_model):
    """Return (token_count, target_token_count) for *text*.

    The conditioning model is passed in explicitly by the caller rather
    than read from instance state; ``self`` is not consulted.
    """
    result = cond_stage_model.process_texts([text])
    n_tokens = result[1]
    target = cond_stage_model.get_target_prompt_token_count(n_tokens)
    return n_tokens, target
def redo_hijack(self, m):
    # Intentional no-op stub: presumably kept so callers that re-apply the
    # hijack to model `m` keep working — confirm against call sites.
    pass

View File

@@ -167,9 +167,15 @@ def update_token_counter(text, steps, *, is_positive=True):
# messages related to it in console
prompt_schedules = [[[steps, text]]]
try:
cond_stage_model = sd_models.model_data.sd_model.cond_stage_model
assert cond_stage_model is not None
except Exception:
return f"<span class='gr-box gr-text-input'>?/?</span>"
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step, prompt_text in flat_prompts]
token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
token_count, max_length = max([model_hijack.get_prompt_lengths(prompt, cond_stage_model) for prompt in prompts], key=lambda args: args[0])
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"