From ec7adb41faf814a052b7313900a8d2db7b127129 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Mon, 29 Jan 2024 14:41:10 -0800 Subject: [PATCH] Update controlnet.py --- .../sd_forge_controlnet/scripts/controlnet.py | 38 +++++-------------- 1 file changed, 9 insertions(+), 29 deletions(-) diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py index eca42ff9..0b5a5367 100644 --- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py +++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py @@ -1005,40 +1005,20 @@ class Script(scripts.Script, metaclass=( self.detected_map = detected_maps self.post_processors = post_processors - def controlnet_hack(self, p): - t = time.time() - if getattr(shared.cmd_opts, 'controlnet_tracemalloc', False): - tracemalloc.start() - setattr(self, "malloc_begin", tracemalloc.take_snapshot()) + def process_unit_after_click_generate(self, p, unit, *args, **kwargs): + return - self.controlnet_main_entry(p) - if getattr(shared.cmd_opts, 'controlnet_tracemalloc', False): - logger.info("After hook malloc:") - for stat in tracemalloc.take_snapshot().compare_to(self.malloc_begin, "lineno")[:10]: - logger.info(stat) - - if len(self.enabled_units) > 0: - logger.info(f'ControlNet Patched - Time = {time.time() - t}') - - @staticmethod - def process_has_sdxl_refiner(p): - return getattr(p, 'refiner_checkpoint', None) is not None + def process_unit_before_every_sampling(self, p, unit, *args, **kwargs): + return def process(self, p, *args, **kwargs): - if not Script.process_has_sdxl_refiner(p): - self.controlnet_hack(p) + for unit in Script.get_enabled_units(p): + self.process_unit_after_click_generate(p, unit, *args, **kwargs) return - def before_process_batch(self, p, *args, **kwargs): - if Script.process_has_sdxl_refiner(p): - self.controlnet_hack(p) - return - - def postprocess_batch(self, p, *args, **kwargs): - images = kwargs.get('images', []) - for post_processor in self.post_processors: - for i in range(len(images)): - images[i] = post_processor(images[i]) + def process_before_every_sampling(self, p, *args, **kwargs): + for unit in Script.get_enabled_units(p): + self.process_unit_before_every_sampling(p, unit, *args, **kwargs) return def postprocess(self, p, processed, *args):