Update controlnet.py

This commit is contained in:
lllyasviel
2024-01-29 15:26:14 -08:00
parent 176007f7c6
commit 04a062c5e4

View File

@@ -580,19 +580,18 @@ class ControlNetForForgeOfficial(scripts.Script):
Args:
unit (external_code.ControlNetUnit): The ControlNetUnit instance to check.
"""
cfg = preprocessor_sliders_config.get(
global_state.get_module_basename(unit.module), [])
defaults = {
param: cfg_default['value']
for param, cfg_default in zip(
("processor_res", 'threshold_a', 'threshold_b'), cfg)
if cfg_default is not None
}
for param, default_value in defaults.items():
value = getattr(unit, param)
if value < 0:
setattr(unit, param, default_value)
logger.warning(f'[{unit.module}.{param}] Invalid value({value}), using default value {default_value}.')
preprocessor = global_state.get_preprocessor(unit.module)
if unit.processor_res < 0:
unit.processor_res = int(preprocessor.slider_resolution.gradio_update_kwargs.get('value', 512))
if unit.threshold_a < 0:
unit.threshold_a = int(preprocessor.slider_1.gradio_update_kwargs.get('value', 1.0))
if unit.threshold_b < 0:
unit.threshold_b = int(preprocessor.slider_2.gradio_update_kwargs.get('value', 1.0))
return
@staticmethod
def check_sd_version_compatible(unit: external_code.ControlNetUnit) -> None:
@@ -645,52 +644,6 @@ class ControlNetForForgeOfficial(scripts.Script):
return h, w, hr_y, hr_x
def controlnet_main_entry(self, p):
sd_ldm = p.sd_model
unet = sd_ldm.model.diffusion_model
self.noise_modifier = None
setattr(p, 'controlnet_control_loras', [])
if self.latest_network is not None:
# always restore (~0.05s)
self.latest_network.restore()
# always clear (~0.05s)
clear_all_secondary_control_models(unet)
if not batch_hijack.instance.is_batch:
self.enabled_units = Script.get_enabled_units(p)
batch_option_uint_separate = self.ui_batch_option_state[0] == external_code.BatchOption.SEPARATE.value
batch_option_style_align = self.ui_batch_option_state[1]
if len(self.enabled_units) == 0 and not batch_option_style_align:
self.latest_network = None
return
logger.info(f"unit_separate = {batch_option_uint_separate}, style_align = {batch_option_style_align}")
detected_maps = []
forward_params = []
post_processors = []
# cache stuff
if self.latest_model_hash != p.sd_model.sd_model_hash:
Script.clear_control_model_cache()
for idx, unit in enumerate(self.enabled_units):
unit.module = global_state.get_module_basename(unit.module)
# unload unused preproc
module_list = [unit.module for unit in self.enabled_units]
for key in self.unloadable:
if key not in module_list:
self.unloadable.get(key, lambda:None)()
self.latest_model_hash = p.sd_model.sd_model_hash
high_res_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False)
h, w, hr_y, hr_x = Script.get_target_dimensions(p)
for idx, unit in enumerate(self.enabled_units):
Script.bound_check_params(unit)
Script.check_sd_version_compatible(unit)
@@ -960,6 +913,7 @@ class ControlNetForForgeOfficial(scripts.Script):
def process(self, p, *args, **kwargs):
self.current_params = {}
for i, unit in enumerate(self.get_enabled_units(p)):
self.bound_check_params(unit)
params = ControlNetCachedParameters()
self.process_unit_after_click_generate(p, unit, params, *args, **kwargs)
self.current_params[i] = params