mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-27 17:51:22 +00:00
ini mask support
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
from typing import Dict, Optional, Tuple, List, Union
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
import modules.scripts as scripts
|
||||
@@ -46,6 +47,8 @@ class ControlNetCachedParameters:
|
||||
self.model = None
|
||||
self.control_cond = None
|
||||
self.control_cond_for_hr_fix = None
|
||||
self.control_mask = None
|
||||
self.control_mask_for_hr_fix = None
|
||||
|
||||
|
||||
class ControlNetForForgeOfficial(scripts.Script):
|
||||
@@ -95,105 +98,6 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
enabled_units = [x for x in units if x.enabled]
|
||||
return enabled_units
|
||||
|
||||
def choose_input_image(
    self,
    p: processing.StableDiffusionProcessing,
    unit: external_code.ControlNetUnit,
) -> Tuple[np.ndarray, external_code.ResizeMode]:
    """Choose the ControlNet input image and the resize mode to apply to it.

    Sources consulted, in descending priority:
        - unit.image: ControlNet tab input image (a single image dict, or a
          list of image dicts).
        - p.init_images: A1111 img2img tab input image.

    Side effects:
        - ``unit.module`` is set to ``'none'`` when a hand-drawn mask is used
          as the input image.
        - ``image['mask']`` of the parsed unit image is expanded in place to
          at least 3 dimensions.

    Returns:
        - The input image in ndarray form (a list of ndarrays when the unit
          carries multiple images).
        - The resize mode.

    Raises:
        ValueError: neither the unit nor ``p`` provides an input image.
    """

    def parse_unit_image(unit: external_code.ControlNetUnit) -> Union[
            List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
        # A non-empty list whose first entry has an "image" key is treated as
        # a multi-image unit; entries that fail to parse are dropped.
        unit_has_multiple_images = (
            isinstance(unit.image, list) and
            len(unit.image) > 0 and
            "image" in unit.image[0]
        )
        if unit_has_multiple_images:
            parsed = (image_dict_from_any(img) for img in unit.image)
            return [d for d in parsed if d is not None]
        return image_dict_from_any(unit.image)

    def decode_image(img) -> np.ndarray:
        """Need to check the image for API compatibility."""
        if isinstance(img, str):
            # Decode the argument itself. Reading the enclosing `image`
            # variable here breaks as soon as `image` is a list.
            return np.asarray(decode_base64_to_image(img))
        assert isinstance(img, np.ndarray)
        return img

    image = parse_unit_image(unit)

    # `init_images` may be missing or empty; both mean "no A1111 image".
    init_images = getattr(p, "init_images", None)
    a1111_image = init_images[0] if init_images else None

    resize_mode = external_code.resize_mode_from_value(unit.resize_mode)

    if image is not None:
        if isinstance(image, list):
            # Add mask logic if later there is a processor that accepts mask
            # on multiple inputs.
            input_image = [HWC3(decode_image(img['image'])) for img in image]
        else:
            input_image = HWC3(decode_image(image['image']))
            if 'mask' in image and image['mask'] is not None:
                mask = image['mask']
                while len(mask.shape) < 3:
                    mask = mask[..., np.newaxis]
                image['mask'] = mask

                if 'inpaint' in unit.module:
                    logger.info("using inpaint as input")
                    # Reuse the already decoded colour image and append the
                    # first mask channel as a 4th channel.
                    alpha = mask[:, :, 0:1]
                    input_image = np.concatenate([input_image, alpha], axis=2)
                elif (
                    not shared.opts.data.get("controlnet_ignore_noninpaint_mask", False) and
                    # There is a weird gradio issue that would produce a mask
                    # that is not pure color when no scribble is made on canvas.
                    # See https://github.com/Mikubill/sd-webui-controlnet/issues/1638.
                    not (
                        (mask[:, :, 0] <= 5).all() or
                        (mask[:, :, 0] >= 250).all()
                    )
                ):
                    logger.info("using mask as input")
                    input_image = HWC3(mask[:, :, 0])
                    unit.module = 'none'  # Always use black bg and white line
    elif a1111_image is not None:
        input_image = HWC3(np.asarray(a1111_image))
        a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
        assert a1111_i2i_resize_mode is not None
        resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)

        a1111_mask_image = getattr(p, "image_mask", None)
        if 'inpaint' in unit.module:
            if a1111_mask_image is not None:
                a1111_mask = np.array(prepare_mask(a1111_mask_image, p))
                assert a1111_mask.ndim == 2
                assert a1111_mask.shape[0] == input_image.shape[0]
                assert a1111_mask.shape[1] == input_image.shape[1]
                input_image = np.concatenate(
                    [input_image[:, :, 0:3], a1111_mask[:, :, None]], axis=2)
            else:
                # No A1111 mask: append an all-zero 4th channel.
                input_image = np.concatenate([
                    input_image[:, :, 0:3],
                    np.zeros_like(input_image, dtype=np.uint8)[:, :, 0:1],
                ], axis=2)
    else:
        raise ValueError("controlnet is enabled but no input image is given")

    assert isinstance(input_image, (np.ndarray, list))
    return input_image, resize_mode
|
||||
|
||||
@staticmethod
|
||||
def try_crop_image_with_a1111_mask(
|
||||
p: StableDiffusionProcessing,
|
||||
@@ -202,19 +106,6 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
resize_mode: external_code.ResizeMode,
|
||||
preprocessor
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Crop ControlNet input image based on A1111 inpaint mask given.
|
||||
This logic is crucial in upscale scripts, as they use A1111 mask + inpaint_full_res
|
||||
to crop tiles.
|
||||
"""
|
||||
# Note: The method determining whether the active script is an upscale script is purely
|
||||
# based on `extra_generation_params` these scripts attach on `p`, and subject to change
|
||||
# in the future.
|
||||
# TODO: Change this to a more robust condition once A1111 offers a way to verify script name.
|
||||
is_upscale_script = any("upscale" in k.lower() for k in getattr(p, "extra_generation_params", {}).keys())
|
||||
logger.debug(f"is_upscale_script={is_upscale_script}")
|
||||
# Note: `inpaint_full_res` is "inpaint area" on UI. The flag is `True` when "Only masked"
|
||||
# option is selected.
|
||||
a1111_mask_image: Optional[Image.Image] = getattr(p, "image_mask", None)
|
||||
is_only_masked_inpaint = (
|
||||
issubclass(type(p), StableDiffusionProcessingImg2Img) and
|
||||
@@ -224,9 +115,8 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
if (
|
||||
preprocessor.corp_image_with_a1111_mask_when_in_img2img_inpaint_tab
|
||||
and is_only_masked_inpaint
|
||||
and (is_upscale_script or unit.inpaint_crop_input_image)
|
||||
):
|
||||
logger.debug("Crop input image based on A1111 mask.")
|
||||
logger.info("Crop input image based on A1111 mask.")
|
||||
input_image = [input_image[:, :, i] for i in range(input_image.shape[2])]
|
||||
input_image = [Image.fromarray(x) for x in input_image]
|
||||
|
||||
@@ -251,12 +141,47 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
return input_image
|
||||
|
||||
def get_input_data(self, p, unit, preprocessor):
    """Collect the control image, optional control mask and resize mode for a unit.

    Image source, in descending priority:
        - unit.generated_image, when unit.use_preview_as_input is set and a
          generated image exists.
        - p.init_images[0] (with p.resize_mode), when unit.image is None.
        - unit.image['mask'], when unit.image['image'] is entirely below 5
          and the mask has any value above 5.
        - unit.image['image'] otherwise.

    Mask source, in descending priority:
        - p.image_mask, when the image came from p.init_images.
        - unit.mask_image['image'], then unit.mask_image['mask'], then
          unit.image['mask'] -- the first one with any value above 5.
        - None when none of the above applies.

    Returns:
        (image, mask, resize_mode); ``mask`` is None when no mask is available,
        otherwise it is resized to the height/width of ``image``.

    Raises:
        ValueError: no usable input image could be found.
    """
    # `init_images` may be missing or empty; both mean "no A1111 image".
    init_images = getattr(p, "init_images", None)
    a1111_i2i_image = init_images[0] if init_images else None
    a1111_i2i_mask = getattr(p, "image_mask", None)

    using_a1111_data = False

    resize_mode = external_code.resize_mode_from_value(unit.resize_mode)

    if unit.use_preview_as_input and unit.generated_image is not None:
        image = unit.generated_image
    elif unit.image is None:
        if a1111_i2i_image is None:
            # Fail with the intended message instead of tripping over
            # `p.resize_mode` or `np.asarray(None)` further down.
            raise ValueError("controlnet is enabled but no input image is given")
        resize_mode = external_code.resize_mode_from_value(p.resize_mode)
        image = HWC3(np.asarray(a1111_i2i_image))
        using_a1111_data = True
    elif (unit.image['image'] < 5).all() and (unit.image['mask'] > 5).any():
        # Near-black canvas with something drawn on the mask layer:
        # the drawing itself is the input.
        image = unit.image['mask']
    else:
        image = unit.image['image']

    if not isinstance(image, np.ndarray):
        raise ValueError("controlnet is enabled but no input image is given")

    image = HWC3(image)

    if using_a1111_data:
        # img2img without an inpaint mask has no `image_mask`; that is a valid
        # "no mask" situation, not an error.
        mask = HWC3(np.asarray(a1111_i2i_mask)) if a1111_i2i_mask is not None else None
    elif unit.mask_image is not None and (unit.mask_image['image'] > 5).any():
        mask = unit.mask_image['image']
    elif unit.mask_image is not None and (unit.mask_image['mask'] > 5).any():
        mask = unit.mask_image['mask']
    elif unit.image is not None and (unit.image['mask'] > 5).any():
        mask = unit.image['mask']
    else:
        mask = None

    image = self.try_crop_image_with_a1111_mask(p, unit, image, resize_mode, preprocessor)

    if mask is not None:
        # NOTE(review): the mask is resized to the (possibly cropped) image size
        # and then passed through the same crop helper -- confirm this order is
        # intended.
        mask = cv2.resize(HWC3(mask), (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
        mask = self.try_crop_image_with_a1111_mask(p, unit, mask, resize_mode, preprocessor)

    return image, mask, resize_mode
|
||||
|
||||
@staticmethod
|
||||
def get_target_dimensions(p: StableDiffusionProcessing) -> Tuple[int, int, int, int]:
|
||||
@@ -296,9 +221,13 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
and getattr(p, 'enable_hr', False)
|
||||
)
|
||||
|
||||
if unit.use_preview_as_input:
|
||||
unit.module = 'None'
|
||||
|
||||
preprocessor = global_state.get_preprocessor(unit.module)
|
||||
|
||||
input_image, input_mask, resize_mode = self.get_input_data(p, unit, preprocessor)
|
||||
# p.extra_result_images.append(input_image)
|
||||
|
||||
if unit.pixel_perfect:
|
||||
unit.processor_res = external_code.pixel_perfect_resolution(
|
||||
@@ -339,10 +268,24 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
params.control_cond_for_hr_fix = preprocessor_output
|
||||
p.extra_result_images.append(input_image)
|
||||
|
||||
if input_mask is not None:
|
||||
fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill
|
||||
params.control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border)
|
||||
p.extra_result_images.append(params.control_mask)
|
||||
params.control_mask = numpy_to_pytorch(params.control_mask).movedim(-1, 1)[:, :1]
|
||||
|
||||
if has_high_res_fix:
|
||||
params.control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border)
|
||||
p.extra_result_images.append(params.control_mask_for_hr_fix)
|
||||
params.control_mask_for_hr_fix = numpy_to_pytorch(params.control_mask_for_hr_fix).movedim(-1, 1)[:, :1]
|
||||
else:
|
||||
params.control_mask_for_hr_fix = params.control_mask
|
||||
|
||||
if preprocessor.do_not_need_model:
|
||||
model_filename = 'Not Needed'
|
||||
params.model = ControlModelPatcher()
|
||||
else:
|
||||
assert unit.model != 'None', 'You have not selected any control model!'
|
||||
model_filename = global_state.get_controlnet_filename(unit.model)
|
||||
params.model = cached_controlnet_loader(model_filename)
|
||||
assert params.model is not None, logger.error(f"Recognizing Control Model failed: {model_filename}")
|
||||
@@ -366,16 +309,13 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
|
||||
if is_hr_pass:
|
||||
cond = params.control_cond_for_hr_fix
|
||||
mask = params.control_mask_for_hr_fix
|
||||
else:
|
||||
cond = params.control_cond
|
||||
mask = params.control_mask
|
||||
|
||||
kwargs.update(dict(unit=unit, params=params))
|
||||
|
||||
# CN inpaint fix
|
||||
if isinstance(cond, torch.Tensor) and cond.ndim == 4 and cond.shape[1] == 4:
|
||||
kwargs['cond_before_inpaint_fix'] = cond.clone()
|
||||
cond = cond[:, :3] * (1.0 - cond[:, 3:]) - cond[:, 3:]
|
||||
|
||||
params.model.strength = float(unit.weight)
|
||||
params.model.start_percent = float(unit.guidance_start)
|
||||
params.model.end_percent = float(unit.guidance_end)
|
||||
@@ -411,8 +351,8 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
params.model.positive_advanced_weighting = soft_weighting.copy()
|
||||
params.model.negative_advanced_weighting = soft_weighting.copy()
|
||||
|
||||
params.preprocessor.process_before_every_sampling(p, cond, *args, **kwargs)
|
||||
params.model.process_before_every_sampling(p, cond, *args, **kwargs)
|
||||
cond, mask = params.preprocessor.process_before_every_sampling(p, cond, mask, *args, **kwargs)
|
||||
params.model.process_before_every_sampling(p, cond, mask, *args, **kwargs)
|
||||
|
||||
logger.info(f"ControlNet Method {params.preprocessor.name} patched.")
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user