mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-01-26 19:09:45 +00:00
Add back ControlNet HR option (#90)
* Add back ControlNet HR option * nits * protect kernel * add enum * fix * fix * Update controlnet.py * restore controlnet.py * hint ui * Update controlnet.py * fix * Update controlnet.py * Backend: better controlnet mask batch broadcasting * Update README.md * fix inpaint batch dim align #94 * fix sigmas device in rare cases #71 * rework sigma device mapping * Add hr_option to infotext --------- Co-authored-by: lllyasviel <lyuminzhang@outlook.com>
This commit is contained in:
@@ -16,7 +16,7 @@ from lib_controlnet.controlnet_ui.openpose_editor import OpenposeEditor
|
||||
from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI
|
||||
from lib_controlnet.controlnet_ui.tool_button import ToolButton
|
||||
from lib_controlnet.controlnet_ui.photopea import Photopea
|
||||
from lib_controlnet.enums import InputMode
|
||||
from lib_controlnet.enums import InputMode, HiResFixOption
|
||||
from modules import shared
|
||||
from modules.ui_components import FormRow
|
||||
from modules_forge.forge_util import HWC3
|
||||
@@ -46,7 +46,6 @@ class A1111Context:
|
||||
img2img_inpaint_upload_tab: Optional[gr.components.IOComponent] = None
|
||||
|
||||
img2img_inpaint_area: Optional[gr.components.IOComponent] = None
|
||||
# txt2img_enable_hr is only available for A1111 > 1.7.0.
|
||||
txt2img_enable_hr: Optional[gr.components.IOComponent] = None
|
||||
setting_sd_model_checkpoint: Optional[gr.components.IOComponent] = None
|
||||
|
||||
@@ -580,15 +579,15 @@ class ControlNetUiGroup(object):
|
||||
visible=not self.is_img2img,
|
||||
)
|
||||
|
||||
# self.hr_option = gr.Radio(
|
||||
# choices=[e.value for e in external_code.HiResFixOption],
|
||||
# value=self.default_unit.hr_option.value,
|
||||
# label="Hires-Fix Option",
|
||||
# elem_id=f"{elem_id_tabname}_{tabname}_controlnet_hr_option_radio",
|
||||
# elem_classes="controlnet_hr_option_radio",
|
||||
# visible=False,
|
||||
# )
|
||||
#
|
||||
self.hr_option = gr.Radio(
|
||||
choices=[e.value for e in HiResFixOption],
|
||||
value=self.default_unit.hr_option.value,
|
||||
label="Hires-Fix Option",
|
||||
elem_id=f"{elem_id_tabname}_{tabname}_controlnet_hr_option_radio",
|
||||
elem_classes="controlnet_hr_option_radio",
|
||||
visible=False,
|
||||
)
|
||||
|
||||
# self.loopback = gr.Checkbox(
|
||||
# label="[Batch Loopback] Automatically send generated images to this ControlNet unit in batch generation",
|
||||
# value=self.default_unit.loopback,
|
||||
@@ -612,6 +611,7 @@ class ControlNetUiGroup(object):
|
||||
self.batch_mask_gallery,
|
||||
self.generated_image,
|
||||
self.mask_image,
|
||||
self.hr_option,
|
||||
self.enabled,
|
||||
self.module,
|
||||
self.model,
|
||||
@@ -978,7 +978,12 @@ class ControlNetUiGroup(object):
|
||||
return
|
||||
|
||||
def register_shift_hr_options(self):
|
||||
return
|
||||
ControlNetUiGroup.a1111_context.txt2img_enable_hr.change(
|
||||
fn=lambda checked: gr.update(visible=checked),
|
||||
inputs=[ControlNetUiGroup.a1111_context.txt2img_enable_hr],
|
||||
outputs=[self.hr_option],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
def register_shift_upload_mask(self):
|
||||
"""Controls whether the upload mask input should be visible."""
|
||||
|
||||
@@ -1,5 +1,31 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class HiResFixOption(Enum):
|
||||
BOTH = "Both"
|
||||
LOW_RES_ONLY = "Low res only"
|
||||
HIGH_RES_ONLY = "High res only"
|
||||
|
||||
@staticmethod
|
||||
def from_value(value) -> "HiResFixOption":
|
||||
if isinstance(value, str) and value.startswith("HiResFixOption."):
|
||||
_, field = value.split(".")
|
||||
return getattr(HiResFixOption, field)
|
||||
if isinstance(value, str):
|
||||
return HiResFixOption(value)
|
||||
elif isinstance(value, int):
|
||||
return [x for x in HiResFixOption][value]
|
||||
else:
|
||||
assert isinstance(value, HiResFixOption)
|
||||
return value
|
||||
|
||||
@property
|
||||
def low_res_enabled(self) -> bool:
|
||||
return self in (HiResFixOption.BOTH, HiResFixOption.LOW_RES_ONLY)
|
||||
|
||||
@property
|
||||
def high_res_enabled(self) -> bool:
|
||||
return self in (HiResFixOption.BOTH, HiResFixOption.HIGH_RES_ONLY)
|
||||
|
||||
|
||||
class StableDiffusionVersion(Enum):
|
||||
@@ -43,25 +69,6 @@ class StableDiffusionVersion(Enum):
|
||||
)
|
||||
|
||||
|
||||
class HiResFixOption(Enum):
|
||||
BOTH = "Both"
|
||||
LOW_RES_ONLY = "Low res only"
|
||||
HIGH_RES_ONLY = "High res only"
|
||||
|
||||
@staticmethod
|
||||
def from_value(value: Any) -> "HiResFixOption":
|
||||
if isinstance(value, str) and value.startswith("HiResFixOption."):
|
||||
_, field = value.split(".")
|
||||
return getattr(HiResFixOption, field)
|
||||
if isinstance(value, str):
|
||||
return HiResFixOption(value)
|
||||
elif isinstance(value, int):
|
||||
return [x for x in HiResFixOption][value]
|
||||
else:
|
||||
assert isinstance(value, HiResFixOption)
|
||||
return value
|
||||
|
||||
|
||||
class InputMode(Enum):
|
||||
# Single image to a single ControlNet unit.
|
||||
SIMPLE = "simple"
|
||||
|
||||
@@ -4,8 +4,7 @@ from typing import List, Optional, Union, Tuple, Dict
|
||||
import numpy as np
|
||||
from modules import shared
|
||||
from lib_controlnet.logging import logger
|
||||
from lib_controlnet.enums import InputMode
|
||||
|
||||
from lib_controlnet.enums import InputMode, HiResFixOption
|
||||
from modules.api import api
|
||||
|
||||
|
||||
@@ -156,6 +155,7 @@ class UiControlNetUnit:
|
||||
batch_mask_gallery: list = [],
|
||||
generated_image: Optional[np.ndarray] = None,
|
||||
mask_image: Optional[np.ndarray] = None,
|
||||
hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH
|
||||
enabled: bool = True
|
||||
module: str = "None"
|
||||
model: str = "None"
|
||||
@@ -189,6 +189,7 @@ class UiControlNetUnit:
|
||||
"guidance_end",
|
||||
"pixel_perfect",
|
||||
"control_mode",
|
||||
"hr_option",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusion
|
||||
StableDiffusionProcessing
|
||||
from lib_controlnet.infotext import Infotext
|
||||
from modules_forge.forge_util import HWC3, numpy_to_pytorch
|
||||
from lib_controlnet.enums import HiResFixOption
|
||||
|
||||
import numpy as np
|
||||
import functools
|
||||
@@ -308,6 +309,11 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
logger.info('Batch wise input only support controlnet, control-lora, and t2i adapters!')
|
||||
break
|
||||
|
||||
if has_high_res_fix:
|
||||
hr_option = HiResFixOption.from_value(unit.hr_option)
|
||||
else:
|
||||
hr_option = HiResFixOption.BOTH
|
||||
|
||||
alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)]
|
||||
if preprocessor_output_is_image:
|
||||
params.control_cond = []
|
||||
@@ -315,7 +321,8 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
|
||||
for preprocessor_output in preprocessor_outputs:
|
||||
control_cond = crop_and_resize_image(preprocessor_output, resize_mode, h, w)
|
||||
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond))
|
||||
if hr_option.low_res_enabled:
|
||||
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond))
|
||||
params.control_cond.append(numpy_to_pytorch(control_cond).movedim(-1, 1))
|
||||
|
||||
params.control_cond = torch.cat(params.control_cond, dim=0)[alignment_indices].contiguous()
|
||||
@@ -323,7 +330,8 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
if has_high_res_fix:
|
||||
for preprocessor_output in preprocessor_outputs:
|
||||
control_cond_for_hr_fix = crop_and_resize_image(preprocessor_output, resize_mode, hr_y, hr_x)
|
||||
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond_for_hr_fix))
|
||||
if hr_option.high_res_enabled:
|
||||
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond_for_hr_fix))
|
||||
params.control_cond_for_hr_fix.append(numpy_to_pytorch(control_cond_for_hr_fix).movedim(-1, 1))
|
||||
params.control_cond_for_hr_fix = torch.cat(params.control_cond_for_hr_fix, dim=0)[alignment_indices].contiguous()
|
||||
else:
|
||||
@@ -340,13 +348,15 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
for input_mask in control_masks:
|
||||
fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill
|
||||
control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border)
|
||||
p.extra_result_images.append(control_mask)
|
||||
if hr_option.low_res_enabled:
|
||||
p.extra_result_images.append(control_mask)
|
||||
control_mask = numpy_to_pytorch(control_mask).movedim(-1, 1)[:, :1]
|
||||
params.control_mask.append(control_mask)
|
||||
|
||||
if has_high_res_fix:
|
||||
control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border)
|
||||
p.extra_result_images.append(control_mask_for_hr_fix)
|
||||
if hr_option.high_res_enabled:
|
||||
p.extra_result_images.append(control_mask_for_hr_fix)
|
||||
control_mask_for_hr_fix = numpy_to_pytorch(control_mask_for_hr_fix).movedim(-1, 1)[:, :1]
|
||||
params.control_mask_for_hr_fix.append(control_mask_for_hr_fix)
|
||||
|
||||
@@ -382,6 +392,24 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
|
||||
is_hr_pass = getattr(p, 'is_hr_pass', False)
|
||||
|
||||
has_high_res_fix = (
|
||||
isinstance(p, StableDiffusionProcessingTxt2Img)
|
||||
and getattr(p, 'enable_hr', False)
|
||||
)
|
||||
|
||||
if has_high_res_fix:
|
||||
hr_option = HiResFixOption.from_value(unit.hr_option)
|
||||
else:
|
||||
hr_option = HiResFixOption.BOTH
|
||||
|
||||
if has_high_res_fix and is_hr_pass and (not hr_option.high_res_enabled):
|
||||
logger.info(f"ControlNet Skipped High-res pass.")
|
||||
return
|
||||
|
||||
if has_high_res_fix and (not is_hr_pass) and (not hr_option.low_res_enabled):
|
||||
logger.info(f"ControlNet Skipped Low-res pass.")
|
||||
return
|
||||
|
||||
if is_hr_pass:
|
||||
cond = params.control_cond_for_hr_fix
|
||||
mask = params.control_mask_for_hr_fix
|
||||
|
||||
Reference in New Issue
Block a user