Add back ControlNet HR option (#90)

* Add back ControlNet HR option

* nits

* protect kernel

* add enum

* fix

* fix

* Update controlnet.py

* restore controlnet.py

* hint ui

* Update controlnet.py

* fix

* Update controlnet.py

* Backend: better controlnet mask batch broadcasting

* Update README.md

* fix inpaint batch dim align #94

* fix sigmas device in rare cases #71

* rework sigma device mapping

* Add hr_option to infotext

---------

Co-authored-by: lllyasviel <lyuminzhang@outlook.com>
This commit is contained in:
Chenlei Hu
2024-02-07 16:09:52 +00:00
committed by GitHub
parent 257ac2653a
commit e1faf8327b
4 changed files with 79 additions and 38 deletions

View File

@@ -16,7 +16,7 @@ from lib_controlnet.controlnet_ui.openpose_editor import OpenposeEditor
from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI
from lib_controlnet.controlnet_ui.tool_button import ToolButton
from lib_controlnet.controlnet_ui.photopea import Photopea
from lib_controlnet.enums import InputMode
from lib_controlnet.enums import InputMode, HiResFixOption
from modules import shared
from modules.ui_components import FormRow
from modules_forge.forge_util import HWC3
@@ -46,7 +46,6 @@ class A1111Context:
img2img_inpaint_upload_tab: Optional[gr.components.IOComponent] = None
img2img_inpaint_area: Optional[gr.components.IOComponent] = None
# txt2img_enable_hr is only available for A1111 > 1.7.0.
txt2img_enable_hr: Optional[gr.components.IOComponent] = None
setting_sd_model_checkpoint: Optional[gr.components.IOComponent] = None
@@ -580,15 +579,15 @@ class ControlNetUiGroup(object):
visible=not self.is_img2img,
)
# self.hr_option = gr.Radio(
# choices=[e.value for e in external_code.HiResFixOption],
# value=self.default_unit.hr_option.value,
# label="Hires-Fix Option",
# elem_id=f"{elem_id_tabname}_{tabname}_controlnet_hr_option_radio",
# elem_classes="controlnet_hr_option_radio",
# visible=False,
# )
#
self.hr_option = gr.Radio(
choices=[e.value for e in HiResFixOption],
value=self.default_unit.hr_option.value,
label="Hires-Fix Option",
elem_id=f"{elem_id_tabname}_{tabname}_controlnet_hr_option_radio",
elem_classes="controlnet_hr_option_radio",
visible=False,
)
# self.loopback = gr.Checkbox(
# label="[Batch Loopback] Automatically send generated images to this ControlNet unit in batch generation",
# value=self.default_unit.loopback,
@@ -612,6 +611,7 @@ class ControlNetUiGroup(object):
self.batch_mask_gallery,
self.generated_image,
self.mask_image,
self.hr_option,
self.enabled,
self.module,
self.model,
@@ -978,7 +978,12 @@ class ControlNetUiGroup(object):
return
def register_shift_hr_options(self):
return
ControlNetUiGroup.a1111_context.txt2img_enable_hr.change(
fn=lambda checked: gr.update(visible=checked),
inputs=[ControlNetUiGroup.a1111_context.txt2img_enable_hr],
outputs=[self.hr_option],
show_progress=False,
)
def register_shift_upload_mask(self):
"""Controls whether the upload mask input should be visible."""

View File

@@ -1,5 +1,31 @@
from enum import Enum
from typing import Any
class HiResFixOption(Enum):
BOTH = "Both"
LOW_RES_ONLY = "Low res only"
HIGH_RES_ONLY = "High res only"
@staticmethod
def from_value(value) -> "HiResFixOption":
if isinstance(value, str) and value.startswith("HiResFixOption."):
_, field = value.split(".")
return getattr(HiResFixOption, field)
if isinstance(value, str):
return HiResFixOption(value)
elif isinstance(value, int):
return [x for x in HiResFixOption][value]
else:
assert isinstance(value, HiResFixOption)
return value
@property
def low_res_enabled(self) -> bool:
return self in (HiResFixOption.BOTH, HiResFixOption.LOW_RES_ONLY)
@property
def high_res_enabled(self) -> bool:
return self in (HiResFixOption.BOTH, HiResFixOption.HIGH_RES_ONLY)
class StableDiffusionVersion(Enum):
@@ -43,25 +69,6 @@ class StableDiffusionVersion(Enum):
)
class HiResFixOption(Enum):
BOTH = "Both"
LOW_RES_ONLY = "Low res only"
HIGH_RES_ONLY = "High res only"
@staticmethod
def from_value(value: Any) -> "HiResFixOption":
if isinstance(value, str) and value.startswith("HiResFixOption."):
_, field = value.split(".")
return getattr(HiResFixOption, field)
if isinstance(value, str):
return HiResFixOption(value)
elif isinstance(value, int):
return [x for x in HiResFixOption][value]
else:
assert isinstance(value, HiResFixOption)
return value
class InputMode(Enum):
# Single image to a single ControlNet unit.
SIMPLE = "simple"

View File

@@ -4,8 +4,7 @@ from typing import List, Optional, Union, Tuple, Dict
import numpy as np
from modules import shared
from lib_controlnet.logging import logger
from lib_controlnet.enums import InputMode
from lib_controlnet.enums import InputMode, HiResFixOption
from modules.api import api
@@ -156,6 +155,7 @@ class UiControlNetUnit:
batch_mask_gallery: list = [],
generated_image: Optional[np.ndarray] = None,
mask_image: Optional[np.ndarray] = None,
hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH
enabled: bool = True
module: str = "None"
model: str = "None"
@@ -189,6 +189,7 @@ class UiControlNetUnit:
"guidance_end",
"pixel_perfect",
"control_mode",
"hr_option",
)

View File

@@ -20,6 +20,7 @@ from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusion
StableDiffusionProcessing
from lib_controlnet.infotext import Infotext
from modules_forge.forge_util import HWC3, numpy_to_pytorch
from lib_controlnet.enums import HiResFixOption
import numpy as np
import functools
@@ -308,6 +309,11 @@ class ControlNetForForgeOfficial(scripts.Script):
logger.info('Batch wise input only support controlnet, control-lora, and t2i adapters!')
break
if has_high_res_fix:
hr_option = HiResFixOption.from_value(unit.hr_option)
else:
hr_option = HiResFixOption.BOTH
alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)]
if preprocessor_output_is_image:
params.control_cond = []
@@ -315,7 +321,8 @@ class ControlNetForForgeOfficial(scripts.Script):
for preprocessor_output in preprocessor_outputs:
control_cond = crop_and_resize_image(preprocessor_output, resize_mode, h, w)
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond))
if hr_option.low_res_enabled:
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond))
params.control_cond.append(numpy_to_pytorch(control_cond).movedim(-1, 1))
params.control_cond = torch.cat(params.control_cond, dim=0)[alignment_indices].contiguous()
@@ -323,7 +330,8 @@ class ControlNetForForgeOfficial(scripts.Script):
if has_high_res_fix:
for preprocessor_output in preprocessor_outputs:
control_cond_for_hr_fix = crop_and_resize_image(preprocessor_output, resize_mode, hr_y, hr_x)
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond_for_hr_fix))
if hr_option.high_res_enabled:
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond_for_hr_fix))
params.control_cond_for_hr_fix.append(numpy_to_pytorch(control_cond_for_hr_fix).movedim(-1, 1))
params.control_cond_for_hr_fix = torch.cat(params.control_cond_for_hr_fix, dim=0)[alignment_indices].contiguous()
else:
@@ -340,13 +348,15 @@ class ControlNetForForgeOfficial(scripts.Script):
for input_mask in control_masks:
fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill
control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border)
p.extra_result_images.append(control_mask)
if hr_option.low_res_enabled:
p.extra_result_images.append(control_mask)
control_mask = numpy_to_pytorch(control_mask).movedim(-1, 1)[:, :1]
params.control_mask.append(control_mask)
if has_high_res_fix:
control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border)
p.extra_result_images.append(control_mask_for_hr_fix)
if hr_option.high_res_enabled:
p.extra_result_images.append(control_mask_for_hr_fix)
control_mask_for_hr_fix = numpy_to_pytorch(control_mask_for_hr_fix).movedim(-1, 1)[:, :1]
params.control_mask_for_hr_fix.append(control_mask_for_hr_fix)
@@ -382,6 +392,24 @@ class ControlNetForForgeOfficial(scripts.Script):
is_hr_pass = getattr(p, 'is_hr_pass', False)
has_high_res_fix = (
isinstance(p, StableDiffusionProcessingTxt2Img)
and getattr(p, 'enable_hr', False)
)
if has_high_res_fix:
hr_option = HiResFixOption.from_value(unit.hr_option)
else:
hr_option = HiResFixOption.BOTH
if has_high_res_fix and is_hr_pass and (not hr_option.high_res_enabled):
logger.info(f"ControlNet Skipped High-res pass.")
return
if has_high_res_fix and (not is_hr_pass) and (not hr_option.low_res_enabled):
logger.info(f"ControlNet Skipped Low-res pass.")
return
if is_hr_pass:
cond = params.control_cond_for_hr_fix
mask = params.control_mask_for_hr_fix