mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-21 15:23:58 +00:00
Update controlnet.py
This commit is contained in:
@@ -1005,40 +1005,20 @@ class Script(scripts.Script, metaclass=(
|
||||
self.detected_map = detected_maps
|
||||
self.post_processors = post_processors
|
||||
|
||||
def controlnet_hack(self, p):
|
||||
t = time.time()
|
||||
if getattr(shared.cmd_opts, 'controlnet_tracemalloc', False):
|
||||
tracemalloc.start()
|
||||
setattr(self, "malloc_begin", tracemalloc.take_snapshot())
|
||||
def process_unit_after_click_generate(self, p, unit, *args, **kwargs):
|
||||
return
|
||||
|
||||
self.controlnet_main_entry(p)
|
||||
if getattr(shared.cmd_opts, 'controlnet_tracemalloc', False):
|
||||
logger.info("After hook malloc:")
|
||||
for stat in tracemalloc.take_snapshot().compare_to(self.malloc_begin, "lineno")[:10]:
|
||||
logger.info(stat)
|
||||
|
||||
if len(self.enabled_units) > 0:
|
||||
logger.info(f'ControlNet Patched - Time = {time.time() - t}')
|
||||
|
||||
@staticmethod
|
||||
def process_has_sdxl_refiner(p):
|
||||
return getattr(p, 'refiner_checkpoint', None) is not None
|
||||
def process_unit_before_every_sampling(self, p, unit, *args, **kwargs):
|
||||
return
|
||||
|
||||
def process(self, p, *args, **kwargs):
|
||||
if not Script.process_has_sdxl_refiner(p):
|
||||
self.controlnet_hack(p)
|
||||
for unit in Script.get_enabled_units(p):
|
||||
self.process_unit_after_click_generate(p, unit, *args, **kwargs)
|
||||
return
|
||||
|
||||
def before_process_batch(self, p, *args, **kwargs):
|
||||
if Script.process_has_sdxl_refiner(p):
|
||||
self.controlnet_hack(p)
|
||||
return
|
||||
|
||||
def postprocess_batch(self, p, *args, **kwargs):
|
||||
images = kwargs.get('images', [])
|
||||
for post_processor in self.post_processors:
|
||||
for i in range(len(images)):
|
||||
images[i] = post_processor(images[i])
|
||||
def process_before_every_sampling(self, p, *args, **kwargs):
|
||||
for unit in Script.get_enabled_units(p):
|
||||
self.process_unit_before_every_sampling(p, unit, *args, **kwargs)
|
||||
return
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
|
||||
Reference in New Issue
Block a user