rework args

This commit is contained in:
lllyasviel
2024-02-01 17:18:59 -08:00
parent 06759ccd7f
commit 9bcea66d08
2 changed files with 45 additions and 154 deletions

View File

@@ -20,6 +20,7 @@ from lib_controlnet.enums import InputMode
from modules import shared
from modules.ui_components import FormRow
from modules_forge.forge_util import HWC3
from lib_controlnet.external_code import UiControlNetUnit
@dataclass
@@ -118,53 +119,6 @@ class A1111Context:
)
class UiControlNetUnit(external_code.ControlNetUnit):
"""The data class that stores all states of a ControlNetUnit."""
def __init__(
self,
input_mode: InputMode = InputMode.SIMPLE,
batch_images: Optional[Union[str, List[external_code.InputImage]]] = None,
output_dir: str = "",
loopback: bool = False,
merge_gallery_files: List[
Dict[Union[Literal["name"], Literal["data"]], str]
] = [],
use_preview_as_input: bool = False,
generated_image: Optional[np.ndarray] = None,
mask_image: Optional[np.ndarray] = None,
enabled: bool = True,
module: Optional[str] = None,
model: Optional[str] = None,
weight: float = 1.0,
image: Optional[Dict[str, np.ndarray]] = None,
*args,
**kwargs,
):
if use_preview_as_input and generated_image is not None:
input_image = generated_image
module = "None"
else:
input_image = image
# Prefer uploaded mask_image over hand-drawn mask.
if input_image is not None and mask_image is not None:
assert isinstance(input_image, dict)
input_image["mask"] = mask_image
if merge_gallery_files and input_mode == InputMode.MERGE:
input_image = [
{"image": read_image(file["name"])} for file in merge_gallery_files
]
super().__init__(enabled, module, model, weight, input_image, *args, **kwargs)
self.is_ui = True
self.input_mode = input_mode
self.batch_images = batch_images
self.output_dir = output_dir
self.loopback = loopback
class ControlNetUiGroup(object):
refresh_symbol = "\U0001f504" # 🔄
switch_values_symbol = "\U000021C5" # ⇅
@@ -237,7 +191,6 @@ class ControlNetUiGroup(object):
self.webcam_mirror = None
self.send_dimen_button = None
self.enabled = None
self.low_vram = None
self.pixel_perfect = None
self.preprocessor_preview = None
self.mask_upload = None
@@ -255,7 +208,6 @@ class ControlNetUiGroup(object):
self.threshold_b = None
self.control_mode = None
self.resize_mode = None
self.loopback = None
self.use_preview_as_input = None
self.openpose_editor = None
self.preset_panel = None
@@ -448,12 +400,12 @@ class ControlNetUiGroup(object):
elem_id=f"{elem_id_tabname}_{tabname}_controlnet_enable_checkbox",
elem_classes=["cnet-unit-enabled"],
)
self.low_vram = gr.Checkbox(
label="Low VRAM",
value=self.default_unit.low_vram,
elem_id=f"{elem_id_tabname}_{tabname}_controlnet_low_vram_checkbox",
visible=False, # Not needed now
)
# self.low_vram = gr.Checkbox(
# label="Low VRAM",
# value=self.default_unit.low_vram,
# elem_id=f"{elem_id_tabname}_{tabname}_controlnet_low_vram_checkbox",
# visible=False, # Not needed now
# )
self.pixel_perfect = gr.Checkbox(
label="Pixel Perfect",
value=self.default_unit.pixel_perfect,
@@ -611,22 +563,22 @@ class ControlNetUiGroup(object):
visible=not self.is_img2img,
)
self.hr_option = gr.Radio(
choices=[e.value for e in external_code.HiResFixOption],
value=self.default_unit.hr_option.value,
label="Hires-Fix Option",
elem_id=f"{elem_id_tabname}_{tabname}_controlnet_hr_option_radio",
elem_classes="controlnet_hr_option_radio",
visible=False,
)
self.loopback = gr.Checkbox(
label="[Batch Loopback] Automatically send generated images to this ControlNet unit in batch generation",
value=self.default_unit.loopback,
elem_id=f"{elem_id_tabname}_{tabname}_controlnet_automatically_send_generated_images_checkbox",
elem_classes="controlnet_loopback_checkbox",
visible=False,
)
# self.hr_option = gr.Radio(
# choices=[e.value for e in external_code.HiResFixOption],
# value=self.default_unit.hr_option.value,
# label="Hires-Fix Option",
# elem_id=f"{elem_id_tabname}_{tabname}_controlnet_hr_option_radio",
# elem_classes="controlnet_hr_option_radio",
# visible=False,
# )
#
# self.loopback = gr.Checkbox(
# label="[Batch Loopback] Automatically send generated images to this ControlNet unit in batch generation",
# value=self.default_unit.loopback,
# elem_id=f"{elem_id_tabname}_{tabname}_controlnet_automatically_send_generated_images_checkbox",
# elem_classes="controlnet_loopback_checkbox",
# visible=False,
# )
self.preset_panel = ControlNetPresetUI(
id_prefix=f"{elem_id_tabname}_{tabname}_"
@@ -636,24 +588,15 @@ class ControlNetUiGroup(object):
self.output_dir_state = gr.State("")
unit_args = (
self.input_mode,
self.batch_image_dir_state,
self.output_dir_state,
self.loopback,
# Non-persistent fields.
# Following inputs will not be persistent on `ControlNetUnit`.
# They are only used during object construction.
self.merge_gallery,
self.use_preview_as_input,
self.generated_image,
self.mask_image,
# End of Non-persistent fields.
self.enabled,
self.module,
self.model,
self.weight,
self.image,
self.resize_mode,
self.low_vram,
self.processor_res,
self.threshold_a,
self.threshold_b,
@@ -661,8 +604,6 @@ class ControlNetUiGroup(object):
self.guidance_end,
self.pixel_perfect,
self.control_mode,
self.inpaint_crop_input_image,
self.hr_option,
)
unit = gr.State(self.default_unit)
@@ -1018,7 +959,7 @@ class ControlNetUiGroup(object):
gr.update(value=None),
gr.update(value=None),
gr.update(value=False, visible=x),
] + [gr.update(visible=x)] * 4
] + [gr.update(visible=x)] * 3
self.upload_independent_img_in_img2img.change(
fn_same_checked,
@@ -1029,7 +970,6 @@ class ControlNetUiGroup(object):
self.preprocessor_preview,
self.image_upload_panel,
self.trigger_preprocessor,
self.loopback,
self.resize_mode,
],
show_progress=False,
@@ -1072,16 +1012,17 @@ class ControlNetUiGroup(object):
ControlNetUiGroup.a1111_context.img2img_inpaint_area.change(**gradio_kwargs)
def register_shift_hr_options(self):
# A1111 version < 1.6.0.
if not ControlNetUiGroup.a1111_context.txt2img_enable_hr:
return
ControlNetUiGroup.a1111_context.txt2img_enable_hr.change(
fn=lambda checked: gr.update(visible=checked),
inputs=[ControlNetUiGroup.a1111_context.txt2img_enable_hr],
outputs=[self.hr_option],
show_progress=False,
)
# # A1111 version < 1.6.0.
# if not ControlNetUiGroup.a1111_context.txt2img_enable_hr:
# return
#
# ControlNetUiGroup.a1111_context.txt2img_enable_hr.change(
# fn=lambda checked: gr.update(visible=checked),
# inputs=[ControlNetUiGroup.a1111_context.txt2img_enable_hr],
# outputs=[self.hr_option],
# show_progress=False,
# )
return
def register_shift_upload_mask(self):
"""Controls whether the upload mask input should be visible."""
@@ -1268,11 +1209,11 @@ class ControlNetUiGroup(object):
for ui_group in ui_groups:
batch_fn = lambda: InputMode.BATCH
simple_fn = lambda: InputMode.SIMPLE
# merge_fn = lambda: InputMode.MERGE
merge_fn = lambda: InputMode.MERGE
for input_tab, fn in (
(ui_group.upload_tab, simple_fn),
(ui_group.batch_tab, batch_fn),
# (ui_group.merge_tab, merge_fn),
(ui_group.merge_tab, merge_fn),
):
# Sync input_mode.
input_tab.select(
@@ -1280,19 +1221,6 @@ class ControlNetUiGroup(object):
inputs=[],
outputs=[ui_group.input_mode],
show_progress=False,
).then(
# Update visibility of loopback checkbox.
fn=lambda *mode_values: (
(
gr.update(
visible=any(m == InputMode.BATCH for m in mode_values)
),
)
* len(ui_groups)
),
inputs=[g.input_mode for g in ui_groups],
outputs=[g.loopback for g in ui_groups],
show_progress=False,
)
@staticmethod

View File

@@ -6,7 +6,7 @@ import numpy as np
from modules import scripts, processing, shared
from lib_controlnet import global_state
from lib_controlnet.logging import logger
from lib_controlnet.enums import HiResFixOption
from lib_controlnet.enums import HiResFixOption, InputMode
from modules.api import api
@@ -149,17 +149,17 @@ InputImage = Union[Dict[str, InputImage], Tuple[InputImage, InputImage], InputIm
@dataclass
class ControlNetUnit:
"""
Represents an entire ControlNet processing unit.
"""
class UiControlNetUnit:
input_mode: InputMode = InputMode.SIMPLE
use_preview_as_input: bool = False,
generated_image: Optional[np.ndarray] = None,
mask_image: Optional[np.ndarray] = None,
enabled: bool = True
module: str = "None"
model: str = "None"
weight: float = 1.0
image: Optional[Union[InputImage, List[InputImage]]] = None
resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT
low_vram: bool = False
processor_res: int = -1
threshold_a: float = -1
threshold_b: float = -1
@@ -167,47 +167,10 @@ class ControlNetUnit:
guidance_end: float = 1.0
pixel_perfect: bool = False
control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED
# Whether to crop input image based on A1111 img2img mask. This flag is only used when `inpaint area`
# in A1111 is set to `Only masked`. In API, this correspond to `inpaint_full_res = True`.
inpaint_crop_input_image: bool = True
# If hires fix is enabled in A1111, how should this ControlNet unit be applied.
# The value is ignored if the generation is not using hires fix.
hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH
# Whether save the detected map of this unit. Setting this option to False prevents saving the
# detected map or sending detected map along with generated images via API.
# Currently the option is only accessible in API calls.
save_detected_map: bool = True
# Weight for each layer of ControlNet params.
# For ControlNet:
# - SD1.5: 13 weights (4 encoder block * 3 + 1 middle block)
# - SDXL: 10 weights (3 encoder block * 3 + 1 middle block)
# For T2IAdapter
# - SD1.5: 5 weights (4 encoder block + 1 middle block)
# - SDXL: 4 weights (3 encoder block + 1 middle block)
# Note1: Setting advanced weighting will disable `soft_injection`, i.e.
# It is recommended to set ControlMode = BALANCED when using `advanced_weighting`.
# Note2: The field `weight` is still used in some places, e.g. reference_only,
# even advanced_weighting is set.
advanced_weighting: Optional[List[float]] = None
def __eq__(self, other):
if not isinstance(other, ControlNetUnit):
return False
return vars(self) == vars(other)
def accepts_multiple_inputs(self) -> bool:
"""This unit can accept multiple input images."""
return self.module in (
"ip-adapter_clip_sdxl",
"ip-adapter_clip_sdxl_plus_vith",
"ip-adapter_clip_sd15",
"ip-adapter_face_id",
"ip-adapter_face_id_plus",
"instant_id_face_embedding",
)
# Backward Compatible
ControlNetUnit = UiControlNetUnit
def to_base64_nparray(encoding: str):