From 04a062c5e492f5633a5dcc8ab34826ec3626ff28 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Mon, 29 Jan 2024 15:26:14 -0800 Subject: [PATCH] Update controlnet.py --- .../sd_forge_controlnet/scripts/controlnet.py | 72 ++++--------------- 1 file changed, 13 insertions(+), 59 deletions(-) diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py index dbe1bf6e..91a18112 100644 --- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py +++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py @@ -580,19 +580,18 @@ class ControlNetForForgeOfficial(scripts.Script): Args: unit (external_code.ControlNetUnit): The ControlNetUnit instance to check. """ - cfg = preprocessor_sliders_config.get( - global_state.get_module_basename(unit.module), []) - defaults = { - param: cfg_default['value'] - for param, cfg_default in zip( - ("processor_res", 'threshold_a', 'threshold_b'), cfg) - if cfg_default is not None - } - for param, default_value in defaults.items(): - value = getattr(unit, param) - if value < 0: - setattr(unit, param, default_value) - logger.warning(f'[{unit.module}.{param}] Invalid value({value}), using default value {default_value}.') + preprocessor = global_state.get_preprocessor(unit.module) + + if unit.processor_res < 0: + unit.processor_res = int(preprocessor.slider_resolution.gradio_update_kwargs.get('value', 512)) + + if unit.threshold_a < 0: + unit.threshold_a = int(preprocessor.slider_1.gradio_update_kwargs.get('value', 1.0)) + + if unit.threshold_b < 0: + unit.threshold_b = int(preprocessor.slider_2.gradio_update_kwargs.get('value', 1.0)) + + return @staticmethod def check_sd_version_compatible(unit: external_code.ControlNetUnit) -> None: @@ -645,52 +644,6 @@ class ControlNetForForgeOfficial(scripts.Script): return h, w, hr_y, hr_x def controlnet_main_entry(self, p): - sd_ldm = p.sd_model - unet = sd_ldm.model.diffusion_model - self.noise_modifier = None - - setattr(p, 'controlnet_control_loras', []) - - if self.latest_network is not None: - # always restore (~0.05s) - self.latest_network.restore() - - # always clear (~0.05s) - clear_all_secondary_control_models(unet) - - if not batch_hijack.instance.is_batch: - self.enabled_units = Script.get_enabled_units(p) - - batch_option_uint_separate = self.ui_batch_option_state[0] == external_code.BatchOption.SEPARATE.value - batch_option_style_align = self.ui_batch_option_state[1] - - if len(self.enabled_units) == 0 and not batch_option_style_align: - self.latest_network = None - return - - logger.info(f"unit_separate = {batch_option_uint_separate}, style_align = {batch_option_style_align}") - - detected_maps = [] - forward_params = [] - post_processors = [] - - # cache stuff - if self.latest_model_hash != p.sd_model.sd_model_hash: - Script.clear_control_model_cache() - - for idx, unit in enumerate(self.enabled_units): - unit.module = global_state.get_module_basename(unit.module) - - # unload unused preproc - module_list = [unit.module for unit in self.enabled_units] - for key in self.unloadable: - if key not in module_list: - self.unloadable.get(key, lambda:None)() - - self.latest_model_hash = p.sd_model.sd_model_hash - high_res_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False) - h, w, hr_y, hr_x = Script.get_target_dimensions(p) - for idx, unit in enumerate(self.enabled_units): Script.bound_check_params(unit) Script.check_sd_version_compatible(unit) @@ -960,6 +913,7 @@ class ControlNetForForgeOfficial(scripts.Script): def process(self, p, *args, **kwargs): self.current_params = {} for i, unit in enumerate(self.get_enabled_units(p)): + self.bound_check_params(unit) params = ControlNetCachedParameters() self.process_unit_after_click_generate(p, unit, params, *args, **kwargs) self.current_params[i] = params