Update controlnet.py

This commit is contained in:
lllyasviel
2024-01-29 14:41:10 -08:00
parent 08763a234d
commit ec7adb41fa

View File

@@ -1005,40 +1005,20 @@ class Script(scripts.Script, metaclass=(
self.detected_map = detected_maps
self.post_processors = post_processors
def controlnet_hack(self, p):
t = time.time()
if getattr(shared.cmd_opts, 'controlnet_tracemalloc', False):
tracemalloc.start()
setattr(self, "malloc_begin", tracemalloc.take_snapshot())
def process_unit_after_click_generate(self, p, unit, *args, **kwargs):
return
self.controlnet_main_entry(p)
if getattr(shared.cmd_opts, 'controlnet_tracemalloc', False):
logger.info("After hook malloc:")
for stat in tracemalloc.take_snapshot().compare_to(self.malloc_begin, "lineno")[:10]:
logger.info(stat)
if len(self.enabled_units) > 0:
logger.info(f'ControlNet Patched - Time = {time.time() - t}')
@staticmethod
def process_has_sdxl_refiner(p):
return getattr(p, 'refiner_checkpoint', None) is not None
def process_unit_before_every_sampling(self, p, unit, *args, **kwargs):
return
def process(self, p, *args, **kwargs):
if not Script.process_has_sdxl_refiner(p):
self.controlnet_hack(p)
for unit in Script.get_enabled_units(p):
self.process_unit_after_click_generate(p, unit, *args, **kwargs)
return
def before_process_batch(self, p, *args, **kwargs):
if Script.process_has_sdxl_refiner(p):
self.controlnet_hack(p)
return
def postprocess_batch(self, p, *args, **kwargs):
images = kwargs.get('images', [])
for post_processor in self.post_processors:
for i in range(len(images)):
images[i] = post_processor(images[i])
def process_before_every_sampling(self, p, *args, **kwargs):
for unit in Script.get_enabled_units(p):
self.process_unit_before_every_sampling(p, unit, *args, **kwargs)
return
def postprocess(self, p, processed, *args):