fix missing infotext caused by conds cache

some generation params such as TI hashes or Emphasis are added in sd_hijack / sd_hijack_clip
if conds are fetched from cache, sd_hijack_clip will not be executed and it won't have a chance to add generation params

the generation params will also be missing in non-low-vram mode because hijack.extra_generation_params was never read after calculate_hr_conds
This commit is contained in:
w-e-w
2024-11-23 17:31:01 +09:00
parent 023454b49e
commit e72a6c411a
4 changed files with 76 additions and 7 deletions

View File

@@ -187,6 +187,7 @@ class StableDiffusionProcessing:
cached_uc = [None, None]
cached_c = [None, None]
hijack_generation_params_state_list = []
comments: dict = None
sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
@@ -480,6 +481,10 @@ class StableDiffusionProcessing:
for cache in caches:
if cache[0] is not None and cached_params == cache[0]:
if len(cache) == 3:
generation_params_state, cached_params_2 = cache[2]
if cached_params == cached_params_2:
self.hijack_generation_params_state_list.extend(generation_params_state)
return cache[1]
cache = caches[0]
@@ -487,9 +492,25 @@ class StableDiffusionProcessing:
with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
generation_params_state = model_hijack.capture_generation_params_state()
self.hijack_generation_params_state_list.extend(generation_params_state)
if len(cache) == 2:
cache.append((generation_params_state, cached_params))
else:
cache[2] = (generation_params_state, cached_params)
cache[0] = cached_params
return cache[1]
def apply_hijack_generation_params(self):
    """Merge generation params collected by the model hijack into
    ``self.extra_generation_params``.

    First copies ``model_hijack.extra_generation_params``, then replays the
    deferred state callables captured when conds were computed (or restored
    from the conds cache), and finally clears the list so the same callables
    are not applied twice on a later pass.
    """
    self.extra_generation_params.update(model_hijack.extra_generation_params)

    for func in self.hijack_generation_params_state_list:
        try:
            func(self.extra_generation_params)
        except Exception:
            # best-effort: a failing callback must not abort image generation
            errors.report("Failed to apply hijack generation params state", exc_info=True)

    self.hijack_generation_params_state_list.clear()
def setup_conds(self):
prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
@@ -502,6 +523,8 @@ class StableDiffusionProcessing:
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
self.apply_hijack_generation_params()
def get_conds(self):
    """Return the cached conditionings as a ``(c, uc)`` pair."""
    conds = (self.c, self.uc)
    return conds
@@ -965,8 +988,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.setup_conds()
p.extra_generation_params.update(model_hijack.extra_generation_params)
# params.txt should be saved after scripts.process_batch, since the
# infotext could be modified by that callback
# Example: a wildcard processed by process_batch sets an extra model
@@ -1513,6 +1534,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
self.apply_hijack_generation_params()
def setup_conds(self):
if self.is_hr_pass:
# if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model