From e72a6c411a839921ae81cf5e423b5e82f26e5a9b Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 23 Nov 2024 17:31:01 +0900 Subject: [PATCH] fix missing infotext cased by conda cache some generation params such as TI hashes or Emphasis is added in sd_hijack / sd_hijack_clip if conda are fetche from cache sd_hijack_clip will not be executed and it won't have a chance to to add generation params the generation params will also be missing if in non low-vram mode because the hijack.extra_generation_params was never read after calculate_hr_conds --- modules/processing.py | 27 +++++++++++++++++++++++++-- modules/sd_hijack.py | 8 ++++++++ modules/sd_hijack_clip.py | 33 ++++++++++++++++++++++++++++----- modules/util.py | 15 +++++++++++++++ 4 files changed, 76 insertions(+), 7 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 92c3582cc..690533d49 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -187,6 +187,7 @@ class StableDiffusionProcessing: cached_uc = [None, None] cached_c = [None, None] + hijack_generation_params_state_list = [] comments: dict = None sampler: sd_samplers_common.Sampler | None = field(default=None, init=False) @@ -480,6 +481,10 @@ class StableDiffusionProcessing: for cache in caches: if cache[0] is not None and cached_params == cache[0]: + if len(cache) == 3: + generation_params_state, cached_params_2 = cache[2] + if cached_params == cached_params_2: + self.hijack_generation_params_state_list.extend(generation_params_state) return cache[1] cache = caches[0] @@ -487,9 +492,25 @@ class StableDiffusionProcessing: with devices.autocast(): cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling) + generation_params_state = model_hijack.capture_generation_params_state() + self.hijack_generation_params_state_list.extend(generation_params_state) + if len(cache) == 2: + cache.append((generation_params_state, cached_params)) + else: + cache[2] = (generation_params_state, cached_params) + cache[0] = cached_params return cache[1] + def apply_hijack_generation_params(self): + self.extra_generation_params.update(model_hijack.extra_generation_params) + for func in self.hijack_generation_params_state_list: + try: + func(self.extra_generation_params) + except Exception: + errors.report(f"Failed to apply hijack generation params state", exc_info=True) + self.hijack_generation_params_state_list.clear() + def setup_conds(self): prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height) negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True) @@ -502,6 +523,8 @@ class StableDiffusionProcessing: self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data) self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data) + self.apply_hijack_generation_params() + def get_conds(self): return self.c, self.uc @@ -965,8 +988,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.setup_conds() - p.extra_generation_params.update(model_hijack.extra_generation_params) - # params.txt should be saved after scripts.process_batch, since the # infotext could be modified by that callback # Example: a wildcard processed by process_batch sets an extra model @@ -1513,6 +1534,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps) self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps) + self.apply_hijack_generation_params() + def setup_conds(self): if self.is_hr_pass: # if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 0de830541..05d24fc01 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -6,6 +6,7 @@ from modules import devices, sd_hijack_optimizations, shared, script_callbacks, from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18 +from modules.util import GenerationParamsState import ldm.modules.attention import ldm.modules.diffusionmodules.model @@ -321,6 +322,13 @@ class StableDiffusionModelHijack: self.comments = [] self.extra_generation_params = {} + def capture_generation_params_state(self): + state = [] + for key in list(self.extra_generation_params): + if isinstance(self.extra_generation_params[key], GenerationParamsState): + state.append(self.extra_generation_params.pop(key)) + return state + def get_prompt_lengths(self, text): if self.clip is None: return "-", "-" diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index a479148fc..0cd23fc39 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -5,6 +5,7 @@ import torch from modules import prompt_parser, devices, sd_hijack, sd_emphasis from modules.shared import opts +from modules.util import GenerationParamsState class PromptChunk: @@ -27,6 +28,31 @@ chunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenC are applied by sd_hijack.EmbeddingsWithFixes's forward function.""" +class EmbeddingHashes(GenerationParamsState): + def __init__(self, hashes: list): + super().__init__() + self.hashes = hashes + + def __call__(self, extra_generation_params): + unique_hashes = dict.fromkeys(self.hashes) + if existing_ti_hashes := extra_generation_params.get('TI hashes'): + unique_hashes.update(dict.fromkeys(existing_ti_hashes.split(', '))) + extra_generation_params['TI hashes'] = ', '.join(unique_hashes) + + +class EmphasisMode(GenerationParamsState): + def __init__(self, texts): + super().__init__() + if opts.emphasis != 'Original' and any(x for x in texts if '(' in x or '[' in x): + self.emphasis = opts.emphasis + else: + self.emphasis = None + + def __call__(self, extra_generation_params): + if self.emphasis: + extra_generation_params['Emphasis'] = self.emphasis + + class TextConditionalModel(torch.nn.Module): def __init__(self): super().__init__() @@ -238,12 +264,9 @@ class TextConditionalModel(torch.nn.Module): hashes.append(f"{name}: {shorthash}") if hashes: - if self.hijack.extra_generation_params.get("TI hashes"): - hashes.append(self.hijack.extra_generation_params.get("TI hashes")) - self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes) + self.hijack.extra_generation_params["TI hashes"] = EmbeddingHashes(hashes) - if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original": - self.hijack.extra_generation_params["Emphasis"] = opts.emphasis + self.hijack.extra_generation_params["Emphasis"] = EmphasisMode(texts) if self.return_pooled: return torch.hstack(zs), zs[0].pooled diff --git a/modules/util.py b/modules/util.py index baeba2fa2..a8452c0e0 100644 --- a/modules/util.py +++ b/modules/util.py @@ -288,3 +288,18 @@ def compare_sha256(file_path: str, hash_prefix: str) -> bool: for chunk in iter(lambda: f.read(blksize), b""): hash_sha256.update(chunk) return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower()) + + +class GenerationParamsState: + """A custom class used in StableDiffusionModelHijack for assigning extra_generation_params + generation_params assigned using this class will work properly with StableDiffusionProcessing.get_conds_with_caching() + if assigned directly the generation_params will not be populated if conda cache is used + + Generation_params of this class will be captured (see StableDiffusionModelHijack.capture_generation_params_state) and stored with conda cache, and will be extracted in StableDiffusionProcessing.apply_hijack_generation_params() + + To use this class, create a subclass with a __call__ method that takes extra_generation_params: dict as input + + Example usage: sd_hijack_clip.EmbeddingHashes, sd_hijack_clip.EmphasisMode + """ + def __call__(self, extra_generation_params: dict): + raise NotImplementedError