diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py index dab2ebbd..9c3f2df6 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py @@ -16,7 +16,7 @@ from lib_controlnet.controlnet_ui.openpose_editor import OpenposeEditor from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI from lib_controlnet.controlnet_ui.tool_button import ToolButton from lib_controlnet.controlnet_ui.photopea import Photopea -from lib_controlnet.enums import InputMode +from lib_controlnet.enums import InputMode, HiResFixOption from modules import shared from modules.ui_components import FormRow from modules_forge.forge_util import HWC3 @@ -46,7 +46,6 @@ class A1111Context: img2img_inpaint_upload_tab: Optional[gr.components.IOComponent] = None img2img_inpaint_area: Optional[gr.components.IOComponent] = None - # txt2img_enable_hr is only available for A1111 > 1.7.0. txt2img_enable_hr: Optional[gr.components.IOComponent] = None setting_sd_model_checkpoint: Optional[gr.components.IOComponent] = None @@ -580,15 +579,15 @@ class ControlNetUiGroup(object): visible=not self.is_img2img, ) - # self.hr_option = gr.Radio( - # choices=[e.value for e in external_code.HiResFixOption], - # value=self.default_unit.hr_option.value, - # label="Hires-Fix Option", - # elem_id=f"{elem_id_tabname}_{tabname}_controlnet_hr_option_radio", - # elem_classes="controlnet_hr_option_radio", - # visible=False, - # ) - # + self.hr_option = gr.Radio( + choices=[e.value for e in HiResFixOption], + value=self.default_unit.hr_option.value, + label="Hires-Fix Option", + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_hr_option_radio", + elem_classes="controlnet_hr_option_radio", + visible=False, + ) + # self.loopback = gr.Checkbox( # label="[Batch Loopback] Automatically send generated images to this ControlNet unit in batch generation", # value=self.default_unit.loopback, @@ -612,6 +611,7 @@ class ControlNetUiGroup(object): self.batch_mask_gallery, self.generated_image, self.mask_image, + self.hr_option, self.enabled, self.module, self.model, @@ -978,7 +978,12 @@ class ControlNetUiGroup(object): return def register_shift_hr_options(self): - return + ControlNetUiGroup.a1111_context.txt2img_enable_hr.change( + fn=lambda checked: gr.update(visible=checked), + inputs=[ControlNetUiGroup.a1111_context.txt2img_enable_hr], + outputs=[self.hr_option], + show_progress=False, + ) def register_shift_upload_mask(self): """Controls whether the upload mask input should be visible.""" diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/enums.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/enums.py index 05dc8a63..ad64c321 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/enums.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/enums.py @@ -1,5 +1,31 @@ from enum import Enum -from typing import Any + + +class HiResFixOption(Enum): + BOTH = "Both" + LOW_RES_ONLY = "Low res only" + HIGH_RES_ONLY = "High res only" + + @staticmethod + def from_value(value) -> "HiResFixOption": + if isinstance(value, str) and value.startswith("HiResFixOption."): + _, field = value.split(".") + return getattr(HiResFixOption, field) + if isinstance(value, str): + return HiResFixOption(value) + elif isinstance(value, int): + return [x for x in HiResFixOption][value] + else: + assert isinstance(value, HiResFixOption) + return value + + @property + def low_res_enabled(self) -> bool: + return self in (HiResFixOption.BOTH, HiResFixOption.LOW_RES_ONLY) + + @property + def high_res_enabled(self) -> bool: + return self in (HiResFixOption.BOTH, HiResFixOption.HIGH_RES_ONLY) class StableDiffusionVersion(Enum): @@ -43,25 +69,6 @@ class StableDiffusionVersion(Enum): ) -class HiResFixOption(Enum): - BOTH = "Both" - LOW_RES_ONLY = "Low res only" - HIGH_RES_ONLY = "High res only" - - @staticmethod - def from_value(value: Any) -> "HiResFixOption": - if isinstance(value, str) and value.startswith("HiResFixOption."): - _, field = value.split(".") - return getattr(HiResFixOption, field) - if isinstance(value, str): - return HiResFixOption(value) - elif isinstance(value, int): - return [x for x in HiResFixOption][value] - else: - assert isinstance(value, HiResFixOption) - return value - - class InputMode(Enum): # Single image to a single ControlNet unit. SIMPLE = "simple" diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py index df572028..a5d2bb4f 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py @@ -4,8 +4,7 @@ from typing import List, Optional, Union, Tuple, Dict import numpy as np from modules import shared from lib_controlnet.logging import logger -from lib_controlnet.enums import InputMode - +from lib_controlnet.enums import InputMode, HiResFixOption from modules.api import api @@ -156,6 +155,7 @@ class UiControlNetUnit: batch_mask_gallery: list = [], generated_image: Optional[np.ndarray] = None, mask_image: Optional[np.ndarray] = None, + hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH enabled: bool = True module: str = "None" model: str = "None" @@ -189,6 +189,7 @@ class UiControlNetUnit: "guidance_end", "pixel_perfect", "control_mode", + "hr_option", ) diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py index 8f0a7b85..904937ff 100644 --- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py +++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py @@ -20,6 +20,7 @@ from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusion StableDiffusionProcessing from lib_controlnet.infotext import Infotext from modules_forge.forge_util import HWC3, numpy_to_pytorch +from lib_controlnet.enums import HiResFixOption import numpy as np import functools @@ -308,6 +309,11 @@ class ControlNetForForgeOfficial(scripts.Script): logger.info('Batch wise input only support controlnet, control-lora, and t2i adapters!') break + if has_high_res_fix: + hr_option = HiResFixOption.from_value(unit.hr_option) + else: + hr_option = HiResFixOption.BOTH + alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)] if preprocessor_output_is_image: params.control_cond = [] @@ -315,7 +321,8 @@ class ControlNetForForgeOfficial(scripts.Script): for preprocessor_output in preprocessor_outputs: control_cond = crop_and_resize_image(preprocessor_output, resize_mode, h, w) - p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond)) + if hr_option.low_res_enabled: + p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond)) params.control_cond.append(numpy_to_pytorch(control_cond).movedim(-1, 1)) params.control_cond = torch.cat(params.control_cond, dim=0)[alignment_indices].contiguous() @@ -323,7 +330,8 @@ class ControlNetForForgeOfficial(scripts.Script): if has_high_res_fix: for preprocessor_output in preprocessor_outputs: control_cond_for_hr_fix = crop_and_resize_image(preprocessor_output, resize_mode, hr_y, hr_x) - p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond_for_hr_fix)) + if hr_option.high_res_enabled: + p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond_for_hr_fix)) params.control_cond_for_hr_fix.append(numpy_to_pytorch(control_cond_for_hr_fix).movedim(-1, 1)) params.control_cond_for_hr_fix = torch.cat(params.control_cond_for_hr_fix, dim=0)[alignment_indices].contiguous() else: @@ -340,13 +348,15 @@ class ControlNetForForgeOfficial(scripts.Script): for input_mask in control_masks: fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border) - p.extra_result_images.append(control_mask) + if hr_option.low_res_enabled: + p.extra_result_images.append(control_mask) control_mask = numpy_to_pytorch(control_mask).movedim(-1, 1)[:, :1] params.control_mask.append(control_mask) if has_high_res_fix: control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border) - p.extra_result_images.append(control_mask_for_hr_fix) + if hr_option.high_res_enabled: + p.extra_result_images.append(control_mask_for_hr_fix) control_mask_for_hr_fix = numpy_to_pytorch(control_mask_for_hr_fix).movedim(-1, 1)[:, :1] params.control_mask_for_hr_fix.append(control_mask_for_hr_fix) @@ -382,6 +392,24 @@ class ControlNetForForgeOfficial(scripts.Script): is_hr_pass = getattr(p, 'is_hr_pass', False) + has_high_res_fix = ( + isinstance(p, StableDiffusionProcessingTxt2Img) + and getattr(p, 'enable_hr', False) + ) + + if has_high_res_fix: + hr_option = HiResFixOption.from_value(unit.hr_option) + else: + hr_option = HiResFixOption.BOTH + + if has_high_res_fix and is_hr_pass and (not hr_option.high_res_enabled): + logger.info(f"ControlNet Skipped High-res pass.") + return + + if has_high_res_fix and (not is_hr_pass) and (not hr_option.low_res_enabled): + logger.info(f"ControlNet Skipped Low-res pass.") + return + if is_hr_pass: cond = params.control_cond_for_hr_fix mask = params.control_mask_for_hr_fix