Update controlnet.py

This commit is contained in:
lllyasviel
2024-01-29 15:29:56 -08:00
parent 04a062c5e4
commit d0e3124d62

View File

@@ -404,11 +404,10 @@ class ControlNetForForgeOfficial(scripts.Script):
Infotext.write_infotext(enabled_units, p)
return enabled_units
@staticmethod
def choose_input_image(
self,
p: processing.StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
idx: int
) -> Tuple[np.ndarray, external_code.ResizeMode]:
""" Choose input image from following sources with descending priority:
- p.image_control: [Deprecated] Lagacy way to pass image to controlnet.
@@ -444,25 +443,12 @@ class ControlNetForForgeOfficial(scripts.Script):
return img
# 4 input image sources.
p_image_control = getattr(p, "image_control", None)
p_input_image = Script.get_remote_call(p, "control_net_input_image", None, idx)
image = parse_unit_image(unit)
a1111_image = getattr(p, "init_images", [None])[0]
resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
if p_image_control is not None:
logger.warning("Warn: Using legacy field 'p.image_control'.")
input_image = HWC3(np.asarray(p_image_control))
elif p_input_image is not None:
logger.warning("Warn: Using legacy field 'p.controlnet_input_image'")
if isinstance(p_input_image, dict) and "mask" in p_input_image and "image" in p_input_image:
color = HWC3(np.asarray(p_input_image['image']))
alpha = np.asarray(p_input_image['mask'])[..., None]
input_image = np.concatenate([color, alpha], axis=2)
else:
input_image = HWC3(np.asarray(p_input_image))
elif image:
if image is not None:
if isinstance(image, list):
# Add mask logic if later there is a processor that accepts mask
# on multiple inputs.
@@ -645,41 +631,6 @@ class ControlNetForForgeOfficial(scripts.Script):
def controlnet_main_entry(self, p):
for idx, unit in enumerate(self.enabled_units):
Script.bound_check_params(unit)
Script.check_sd_version_compatible(unit)
if (
"ip-adapter" in unit.module and
not global_state.ip_adapter_pairing_model[unit.module](unit.model)
):
logger.error(f"Invalid pair of IP-Adapter preprocessor({unit.module}) and model({unit.model}).\n"
"Please follow following pairing logic:\n"
+ global_state.ip_adapter_pairing_logic_text)
continue
if (
'inpaint_only' == unit.module and
issubclass(type(p), StableDiffusionProcessingImg2Img) and
p.image_mask is not None
):
logger.warning('A1111 inpaint and ControlNet inpaint duplicated. Falls back to inpaint_global_harmonious.')
unit.module = 'inpaint'
if unit.module in model_free_preprocessors:
model_net = None
if 'reference' in unit.module:
control_model_type = ControlModelType.AttentionInjection
elif 'revision' in unit.module:
control_model_type = ControlModelType.ReVision
else:
raise Exception("Unable to determine control_model_type.")
else:
model_net, control_model_type = Script.load_control_model(p, unet, unit.model)
model_net.reset()
if control_model_type == ControlModelType.ControlLoRA:
control_lora = model_net.control_model
bind_control_lora(unet, control_lora)
p.controlnet_control_loras.append(control_lora)
input_image, resize_mode = Script.choose_input_image(p, unit, idx)
if isinstance(input_image, list):
@@ -903,6 +854,7 @@ class ControlNetForForgeOfficial(scripts.Script):
def process_unit_after_click_generate(self, p, unit, params, *args, **kwargs):
h, w, hr_y, hr_x = self.get_target_dimensions(p)
input_image, resize_mode = self.choose_input_image(p, unit)
return
def process_unit_before_every_sampling(self, p, unit, params, *args, **kwargs):