Files
stable-diffusion-webui-forge/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py
lllyasviel fb2e271668 support inpaint models from fooocus
put inpaint_v26.fooocus.patch in models\ControlNet, control SDXL models only
To get same algorithm as Fooocus, set "Stop at" (Ending Control Step) to 0.5
Fooocus always use 0.5 but in Forge users may use other values.
Results are best when stop at < 0.7. The model is not optimized with ending timesteps > 0.7
Supports inpaint_global_harmonious, inpaint_only, inpaint_only+lama.
In theory the inpaint_only+lama always outperform Fooocus in object removal task (but not all tasks).
2024-02-09 17:08:48 -08:00

571 lines
26 KiB
Python

import os
from typing import Dict, Optional, Tuple, List, Union
import cv2
import torch
import modules.scripts as scripts
from modules import shared, script_callbacks, masking, images
from modules.ui_components import InputAccordion
from modules.api.api import decode_base64_to_image
import gradio as gr
from lib_controlnet import global_state, external_code
from lib_controlnet.utils import align_dim_latent, image_dict_from_any, set_numpy_seed, crop_and_resize_image, \
prepare_mask, judge_image_type
from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit
from lib_controlnet.controlnet_ui.photopea import Photopea
from lib_controlnet.logging import logger
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, \
StableDiffusionProcessing
from lib_controlnet.infotext import Infotext
from modules_forge.forge_util import HWC3, numpy_to_pytorch
from lib_controlnet.enums import HiResFixOption
import numpy as np
import functools
from PIL import Image
from modules_forge.shared import try_load_supported_control_model
from modules_forge.supported_controlnet import ControlModelPatcher
# Gradio 3.32 bug fix
import tempfile
gradio_tempfile_path = os.path.join(tempfile.gettempdir(), 'gradio')
os.makedirs(gradio_tempfile_path, exist_ok=True)
global_state.update_controlnet_filenames()
@functools.lru_cache(maxsize=shared.opts.data.get("control_net_model_cache_size", 5))
def cached_controlnet_loader(filename):
return try_load_supported_control_model(filename)
class ControlNetCachedParameters:
def __init__(self):
self.preprocessor = None
self.model = None
self.control_cond = None
self.control_cond_for_hr_fix = None
self.control_mask = None
self.control_mask_for_hr_fix = None
class ControlNetForForgeOfficial(scripts.Script):
def title(self):
return "ControlNet"
def show(self, is_img2img):
return scripts.AlwaysVisible
def ui(self, is_img2img):
infotext = Infotext()
ui_groups = []
controls = []
max_models = shared.opts.data.get("control_net_unit_count", 3)
gen_type = "img2img" if is_img2img else "txt2img"
elem_id_tabname = gen_type + "_controlnet"
default_unit = UiControlNetUnit(enabled=False, module="None", model="None")
with gr.Group(elem_id=elem_id_tabname):
with gr.Accordion(f"ControlNet Integrated", open=False, elem_id="controlnet",
elem_classes=["controlnet"]):
photopea = (
Photopea()
if not shared.opts.data.get("controlnet_disable_photopea_edit", False)
else None
)
with gr.Row(elem_id=elem_id_tabname + "_accordions", elem_classes="accordions"):
for i in range(max_models):
with InputAccordion(
value=False,
label=f"ControlNet Unit {i}",
elem_classes=["cnet-unit-enabled-accordion"], # Class on accordion
):
group = ControlNetUiGroup(is_img2img, default_unit, photopea)
ui_groups.append(group)
controls.append(group.render(f"ControlNet-{i}", elem_id_tabname))
for i, ui_group in enumerate(ui_groups):
infotext.register_unit(i, ui_group)
if shared.opts.data.get("control_net_sync_field_args", True):
self.infotext_fields = infotext.infotext_fields
self.paste_field_names = infotext.paste_field_names
return tuple(controls)
def get_enabled_units(self, units):
enabled_units = [x for x in units if x.enabled]
return enabled_units
@staticmethod
def try_crop_image_with_a1111_mask(
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
input_image: np.ndarray,
resize_mode: external_code.ResizeMode,
preprocessor
) -> np.ndarray:
a1111_mask_image: Optional[Image.Image] = getattr(p, "image_mask", None)
is_only_masked_inpaint = (
issubclass(type(p), StableDiffusionProcessingImg2Img) and
p.inpaint_full_res and
a1111_mask_image is not None
)
if (
preprocessor.corp_image_with_a1111_mask_when_in_img2img_inpaint_tab
and is_only_masked_inpaint
):
logger.info("Crop input image based on A1111 mask.")
input_image = [input_image[:, :, i] for i in range(input_image.shape[2])]
input_image = [Image.fromarray(x) for x in input_image]
mask = prepare_mask(a1111_mask_image, p)
crop_region = masking.get_crop_region(np.array(mask), p.inpaint_full_res_padding)
crop_region = masking.expand_crop_region(crop_region, p.width, p.height, mask.width, mask.height)
input_image = [
images.resize_image(resize_mode.int_value(), i, mask.width, mask.height)
for i in input_image
]
input_image = [x.crop(crop_region) for x in input_image]
input_image = [
images.resize_image(external_code.ResizeMode.OUTER_FIT.int_value(), x, p.width, p.height)
for x in input_image
]
input_image = [np.asarray(x)[:, :, 0] for x in input_image]
input_image = np.stack(input_image, axis=2)
return input_image
def get_input_data(self, p, unit, preprocessor):
logger.info(f'ControlNet Input Mode: {unit.input_mode}')
resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
if unit.input_mode == external_code.InputMode.MERGE:
image_list = []
for idx, item in enumerate(unit.batch_input_gallery):
img_path = item['name']
logger.info(f'Try to read image: {img_path}')
img = np.ascontiguousarray(cv2.imread(img_path)[:, :, ::-1]).copy()
mask = None
if len(unit.batch_mask_gallery) > 0:
if len(unit.batch_mask_gallery) >= len(unit.batch_input_gallery):
mask_path = unit.batch_mask_gallery[idx]['name']
else:
mask_path = unit.batch_mask_gallery[0]['name']
mask = np.ascontiguousarray(cv2.imread(mask_path)[:, :, ::-1]).copy()
if img is not None:
image_list.append([img, mask])
return image_list, resize_mode
if unit.input_mode == external_code.InputMode.BATCH:
image_list = []
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
for idx, filename in enumerate(os.listdir(unit.batch_image_dir)):
if any(filename.lower().endswith(ext) for ext in image_extensions):
img_path = os.path.join(unit.batch_image_dir, filename)
logger.info(f'Try to read image: {img_path}')
img = np.ascontiguousarray(cv2.imread(img_path)[:, :, ::-1]).copy()
mask = None
if len(unit.batch_mask_dir) > 0:
if len(unit.batch_mask_dir) >= len(unit.batch_image_dir):
mask_path = unit.batch_mask_dir[idx]
else:
mask_path = unit.batch_mask_dir[0]
mask_path = os.path.join(unit.batch_mask_dir, mask_path)
mask = np.ascontiguousarray(cv2.imread(mask_path)[:, :, ::-1]).copy()
if img is not None:
image_list.append([img, mask])
return image_list, resize_mode
a1111_i2i_image = getattr(p, "init_images", [None])[0]
a1111_i2i_mask = getattr(p, "image_mask", None)
using_a1111_data = False
if unit.use_preview_as_input and unit.generated_image is not None:
image = unit.generated_image
elif unit.image is None:
resize_mode = external_code.resize_mode_from_value(p.resize_mode)
image = HWC3(np.asarray(a1111_i2i_image))
using_a1111_data = True
elif (unit.image['image'] < 5).all() and (unit.image['mask'] > 5).any():
image = unit.image['mask']
else:
image = unit.image['image']
if not isinstance(image, np.ndarray):
raise ValueError("controlnet is enabled but no input image is given")
image = HWC3(image)
if using_a1111_data:
mask = HWC3(np.asarray(a1111_i2i_mask)) if a1111_i2i_mask is not None else None
elif unit.mask_image is not None and (unit.mask_image['image'] > 5).any():
mask = unit.mask_image['image']
elif unit.mask_image is not None and (unit.mask_image['mask'] > 5).any():
mask = unit.mask_image['mask']
elif unit.image is not None and (unit.image['mask'] > 5).any():
mask = unit.image['mask']
else:
mask = None
image = self.try_crop_image_with_a1111_mask(p, unit, image, resize_mode, preprocessor)
if mask is not None:
mask = cv2.resize(HWC3(mask), (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
mask = self.try_crop_image_with_a1111_mask(p, unit, mask, resize_mode, preprocessor)
return [[image, mask]], resize_mode
@staticmethod
def get_target_dimensions(p: StableDiffusionProcessing) -> Tuple[int, int, int, int]:
"""Returns (h, w, hr_h, hr_w)."""
h = align_dim_latent(p.height)
w = align_dim_latent(p.width)
high_res_fix = (
isinstance(p, StableDiffusionProcessingTxt2Img)
and getattr(p, 'enable_hr', False)
)
if high_res_fix:
if p.hr_resize_x == 0 and p.hr_resize_y == 0:
hr_y = int(p.height * p.hr_scale)
hr_x = int(p.width * p.hr_scale)
else:
hr_y, hr_x = p.hr_resize_y, p.hr_resize_x
hr_y = align_dim_latent(hr_y)
hr_x = align_dim_latent(hr_x)
else:
hr_y = h
hr_x = w
return h, w, hr_y, hr_x
@torch.no_grad()
def process_unit_after_click_generate(self,
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
params: ControlNetCachedParameters,
*args, **kwargs):
h, w, hr_y, hr_x = self.get_target_dimensions(p)
has_high_res_fix = (
isinstance(p, StableDiffusionProcessingTxt2Img)
and getattr(p, 'enable_hr', False)
)
if unit.use_preview_as_input:
unit.module = 'None'
preprocessor = global_state.get_preprocessor(unit.module)
input_list, resize_mode = self.get_input_data(p, unit, preprocessor)
preprocessor_outputs = []
control_masks = []
preprocessor_output_is_image = False
preprocessor_output = None
def optional_tqdm(iterable, use_tqdm):
from tqdm import tqdm
return tqdm(iterable) if use_tqdm else iterable
for input_image, input_mask in optional_tqdm(input_list, len(input_list) > 1):
# p.extra_result_images.append(input_image)
if unit.pixel_perfect:
unit.processor_res = external_code.pixel_perfect_resolution(
input_image,
target_H=h,
target_W=w,
resize_mode=resize_mode,
)
seed = set_numpy_seed(p)
logger.debug(f"Use numpy seed {seed}.")
logger.info(f"Using preprocessor: {unit.module}")
logger.info(f'preprocessor resolution = {unit.processor_res}')
preprocessor_output = preprocessor(
input_image=input_image,
input_mask=input_mask,
resolution=unit.processor_res,
slider_1=unit.threshold_a,
slider_2=unit.threshold_b,
)
preprocessor_outputs.append(preprocessor_output)
preprocessor_output_is_image = judge_image_type(preprocessor_output)
if input_mask is not None:
control_masks.append(input_mask)
if len(input_list) > 1 and not preprocessor_output_is_image:
logger.info('Batch wise input only support controlnet, control-lora, and t2i adapters!')
break
if has_high_res_fix:
hr_option = HiResFixOption.from_value(unit.hr_option)
else:
hr_option = HiResFixOption.BOTH
alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)]
if preprocessor_output_is_image:
params.control_cond = []
params.control_cond_for_hr_fix = []
for preprocessor_output in preprocessor_outputs:
control_cond = crop_and_resize_image(preprocessor_output, resize_mode, h, w)
if hr_option.low_res_enabled:
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond))
params.control_cond.append(numpy_to_pytorch(control_cond).movedim(-1, 1))
params.control_cond = torch.cat(params.control_cond, dim=0)[alignment_indices].contiguous()
if has_high_res_fix:
for preprocessor_output in preprocessor_outputs:
control_cond_for_hr_fix = crop_and_resize_image(preprocessor_output, resize_mode, hr_y, hr_x)
if hr_option.high_res_enabled:
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond_for_hr_fix))
params.control_cond_for_hr_fix.append(numpy_to_pytorch(control_cond_for_hr_fix).movedim(-1, 1))
params.control_cond_for_hr_fix = torch.cat(params.control_cond_for_hr_fix, dim=0)[alignment_indices].contiguous()
else:
params.control_cond_for_hr_fix = params.control_cond
else:
params.control_cond = preprocessor_output
params.control_cond_for_hr_fix = preprocessor_output
p.extra_result_images.append(input_image)
if len(control_masks) > 0:
params.control_mask = []
params.control_mask_for_hr_fix = []
for input_mask in control_masks:
fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill
control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border)
if hr_option.low_res_enabled:
p.extra_result_images.append(control_mask)
control_mask = numpy_to_pytorch(control_mask).movedim(-1, 1)[:, :1]
params.control_mask.append(control_mask)
if has_high_res_fix:
control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border)
if hr_option.high_res_enabled:
p.extra_result_images.append(control_mask_for_hr_fix)
control_mask_for_hr_fix = numpy_to_pytorch(control_mask_for_hr_fix).movedim(-1, 1)[:, :1]
params.control_mask_for_hr_fix.append(control_mask_for_hr_fix)
params.control_mask = torch.cat(params.control_mask, dim=0)[alignment_indices].contiguous()
if has_high_res_fix:
params.control_mask_for_hr_fix = torch.cat(params.control_mask_for_hr_fix, dim=0)[alignment_indices].contiguous()
else:
params.control_mask_for_hr_fix = params.control_mask
if preprocessor.do_not_need_model:
model_filename = 'Not Needed'
params.model = ControlModelPatcher()
else:
assert unit.model != 'None', 'You have not selected any control model!'
model_filename = global_state.get_controlnet_filename(unit.model)
params.model = cached_controlnet_loader(model_filename)
assert params.model is not None, logger.error(f"Recognizing Control Model failed: {model_filename}")
params.preprocessor = preprocessor
params.preprocessor.process_after_running_preprocessors(process=p, params=params, **kwargs)
params.model.process_after_running_preprocessors(process=p, params=params, **kwargs)
logger.info(f"Current ControlNet {type(params.model).__name__}: {model_filename}")
return
@torch.no_grad()
def process_unit_before_every_sampling(self,
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
params: ControlNetCachedParameters,
*args, **kwargs):
is_hr_pass = getattr(p, 'is_hr_pass', False)
has_high_res_fix = (
isinstance(p, StableDiffusionProcessingTxt2Img)
and getattr(p, 'enable_hr', False)
)
if has_high_res_fix:
hr_option = HiResFixOption.from_value(unit.hr_option)
else:
hr_option = HiResFixOption.BOTH
if has_high_res_fix and is_hr_pass and (not hr_option.high_res_enabled):
logger.info(f"ControlNet Skipped High-res pass.")
return
if has_high_res_fix and (not is_hr_pass) and (not hr_option.low_res_enabled):
logger.info(f"ControlNet Skipped Low-res pass.")
return
if is_hr_pass:
cond = params.control_cond_for_hr_fix
mask = params.control_mask_for_hr_fix
else:
cond = params.control_cond
mask = params.control_mask
kwargs.update(dict(unit=unit, params=params, cond_original=cond.clone(), mask_original=mask.clone()))
params.model.strength = float(unit.weight)
params.model.start_percent = float(unit.guidance_start)
params.model.end_percent = float(unit.guidance_end)
params.model.positive_advanced_weighting = None
params.model.negative_advanced_weighting = None
params.model.advanced_frame_weighting = None
params.model.advanced_sigma_weighting = None
soft_weighting = {
'input': [0.09941396206337118, 0.12050177219802567, 0.14606275417942507, 0.17704576264172736,
0.214600924414215,
0.26012233262329093, 0.3152997971191405, 0.3821815722656249, 0.4632503906249999, 0.561515625,
0.6806249999999999, 0.825],
'middle': [0.561515625] if p.sd_model.is_sdxl else [1.0],
'output': [0.09941396206337118, 0.12050177219802567, 0.14606275417942507, 0.17704576264172736,
0.214600924414215,
0.26012233262329093, 0.3152997971191405, 0.3821815722656249, 0.4632503906249999, 0.561515625,
0.6806249999999999, 0.825]
}
zero_weighting = {
'input': [0.0] * 12,
'middle': [0.0],
'output': [0.0] * 12
}
if unit.control_mode == external_code.ControlMode.CONTROL.value:
params.model.positive_advanced_weighting = soft_weighting.copy()
params.model.negative_advanced_weighting = zero_weighting.copy()
# high-ref fix pass always use softer injections
if is_hr_pass or unit.control_mode == external_code.ControlMode.PROMPT.value:
params.model.positive_advanced_weighting = soft_weighting.copy()
params.model.negative_advanced_weighting = soft_weighting.copy()
cond, mask = params.preprocessor.process_before_every_sampling(p, cond, mask, *args, **kwargs)
params.model.advanced_mask_weighting = mask
params.model.process_before_every_sampling(p, cond, mask, *args, **kwargs)
logger.info(f"ControlNet Method {params.preprocessor.name} patched.")
return
@staticmethod
def bound_check_params(unit: external_code.ControlNetUnit) -> None:
"""
Checks and corrects negative parameters in ControlNetUnit 'unit'.
Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to
their default values if negative.
Args:
unit (external_code.ControlNetUnit): The ControlNetUnit instance to check.
"""
preprocessor = global_state.get_preprocessor(unit.module)
if unit.processor_res < 0:
unit.processor_res = int(preprocessor.slider_resolution.gradio_update_kwargs.get('value', 512))
if unit.threshold_a < 0:
unit.threshold_a = int(preprocessor.slider_1.gradio_update_kwargs.get('value', 1.0))
if unit.threshold_b < 0:
unit.threshold_b = int(preprocessor.slider_2.gradio_update_kwargs.get('value', 1.0))
return
@torch.no_grad()
def process_unit_after_every_sampling(self,
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
params: ControlNetCachedParameters,
*args, **kwargs):
params.preprocessor.process_after_every_sampling(p, params, *args, **kwargs)
params.model.process_after_every_sampling(p, params, *args, **kwargs)
return
@torch.no_grad()
def process(self, p, *args, **kwargs):
self.current_params = {}
enabled_units = self.get_enabled_units(args)
Infotext.write_infotext(enabled_units, p)
for i, unit in enumerate(enabled_units):
self.bound_check_params(unit)
params = ControlNetCachedParameters()
self.process_unit_after_click_generate(p, unit, params, *args, **kwargs)
self.current_params[i] = params
return
@torch.no_grad()
def process_before_every_sampling(self, p, *args, **kwargs):
for i, unit in enumerate(self.get_enabled_units(args)):
self.process_unit_before_every_sampling(p, unit, self.current_params[i], *args, **kwargs)
return
@torch.no_grad()
def postprocess_batch_list(self, p, pp, *args, **kwargs):
for i, unit in enumerate(self.get_enabled_units(args)):
self.process_unit_after_every_sampling(p, unit, self.current_params[i], pp, *args, **kwargs)
return
def postprocess(self, p, processed, *args):
self.current_params = {}
return
def on_ui_settings():
section = ('control_net', "ControlNet")
shared.opts.add_option("control_net_detectedmap_dir", shared.OptionInfo(
"detected_maps", "Directory for detected maps auto saving", section=section))
shared.opts.add_option("control_net_models_path", shared.OptionInfo(
"", "Extra path to scan for ControlNet models (e.g. training output directory)", section=section))
shared.opts.add_option("control_net_modules_path", shared.OptionInfo(
"",
"Path to directory containing annotator model directories (requires restart, overrides corresponding command line flag)",
section=section))
shared.opts.add_option("control_net_unit_count", shared.OptionInfo(
3, "Multi-ControlNet: ControlNet unit number (requires restart)", gr.Slider,
{"minimum": 1, "maximum": 10, "step": 1}, section=section))
shared.opts.add_option("control_net_model_cache_size", shared.OptionInfo(
5, "Model cache size (requires restart)", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}, section=section))
shared.opts.add_option("control_net_no_detectmap", shared.OptionInfo(
False, "Do not append detectmap to output", gr.Checkbox, {"interactive": True}, section=section))
shared.opts.add_option("control_net_detectmap_autosaving", shared.OptionInfo(
False, "Allow detectmap auto saving", gr.Checkbox, {"interactive": True}, section=section))
shared.opts.add_option("control_net_allow_script_control", shared.OptionInfo(
False, "Allow other script to control this extension", gr.Checkbox, {"interactive": True}, section=section))
shared.opts.add_option("control_net_sync_field_args", shared.OptionInfo(
True, "Paste ControlNet parameters in infotext", gr.Checkbox, {"interactive": True}, section=section))
shared.opts.add_option("controlnet_show_batch_images_in_ui", shared.OptionInfo(
False, "Show batch images in gradio gallery output", gr.Checkbox, {"interactive": True}, section=section))
shared.opts.add_option("controlnet_increment_seed_during_batch", shared.OptionInfo(
False, "Increment seed after each controlnet batch iteration", gr.Checkbox, {"interactive": True},
section=section))
shared.opts.add_option("controlnet_disable_openpose_edit", shared.OptionInfo(
False, "Disable openpose edit", gr.Checkbox, {"interactive": True}, section=section))
shared.opts.add_option("controlnet_disable_photopea_edit", shared.OptionInfo(
False, "Disable photopea edit", gr.Checkbox, {"interactive": True}, section=section))
shared.opts.add_option("controlnet_photopea_warning", shared.OptionInfo(
True, "Photopea popup warning", gr.Checkbox, {"interactive": True}, section=section))
shared.opts.add_option("controlnet_input_thumbnail", shared.OptionInfo(
True, "Input image thumbnail on unit header", gr.Checkbox, {"interactive": True}, section=section))
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_infotext_pasted(Infotext.on_infotext_pasted)
script_callbacks.on_after_component(ControlNetUiGroup.on_after_component)
script_callbacks.on_before_reload(ControlNetUiGroup.reset)