Update controlnet.py

This commit is contained in:
lllyasviel
2024-01-29 14:32:18 -08:00
parent ac374e0b97
commit 08763a234d

View File

@@ -1042,66 +1042,7 @@ class Script(scripts.Script, metaclass=(
return
def postprocess(self, p, processed, *args):
sd_ldm = p.sd_model
unet = sd_ldm.model.diffusion_model
clear_all_secondary_control_models(unet)
self.noise_modifier = None
for control_lora in getattr(p, 'controlnet_control_loras', []):
unbind_control_lora(control_lora)
p.controlnet_control_loras = []
self.post_processors = []
setattr(p, 'controlnet_vae_cache', None)
processor_params_flag = (', '.join(getattr(processed, 'extra_generation_params', []))).lower()
self.post_processors = []
if not batch_hijack.instance.is_batch:
self.enabled_units.clear()
if shared.opts.data.get("control_net_detectmap_autosaving", False) and self.latest_network is not None:
for detect_map, module in self.detected_map:
detectmap_dir = os.path.join(shared.opts.data.get("control_net_detectedmap_dir", ""), module)
if not os.path.isabs(detectmap_dir):
detectmap_dir = os.path.join(p.outpath_samples, detectmap_dir)
if module != "none":
os.makedirs(detectmap_dir, exist_ok=True)
img = Image.fromarray(np.ascontiguousarray(detect_map.clip(0, 255).astype(np.uint8)).copy())
save_image(img, detectmap_dir, module)
if self.latest_network is None:
return
if not batch_hijack.instance.is_batch:
if not shared.opts.data.get("control_net_no_detectmap", False):
if 'sd upscale' not in processor_params_flag:
if self.detected_map is not None:
for detect_map, module in self.detected_map:
if detect_map is None:
continue
detect_map = np.ascontiguousarray(detect_map.copy()).copy()
detect_map = external_code.visualize_inpaint_mask(detect_map)
processed.images.extend([
Image.fromarray(
detect_map.clip(0, 255).astype(np.uint8)
)
])
self.input_image = None
self.latest_network.restore()
self.latest_network = None
self.detected_map.clear()
gc.collect()
devices.torch_gc()
if getattr(shared.cmd_opts, 'controlnet_tracemalloc', False):
logger.info("After generation:")
for stat in tracemalloc.take_snapshot().compare_to(self.malloc_begin, "lineno")[:10]:
logger.info(stat)
tracemalloc.stop()
return
def batch_tab_process(self, p, batches, *args, **kwargs):
self.enabled_units = Script.get_enabled_units(p)