mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-28 10:11:42 +00:00
Add back ControlNet HR option (#90)
* Add back ControlNet HR option * nits * protect kernel * add enum * fix * fix * Update controlnet.py * restore controlnet.py * hint ui * Update controlnet.py * fix * Update controlnet.py * Backend: better controlnet mask batch broadcasting * Update README.md * fix inpaint batch dim align #94 * fix sigmas device in rare cases #71 * rework sigma device mapping * Add hr_option to infotext --------- Co-authored-by: lllyasviel <lyuminzhang@outlook.com>
This commit is contained in:
@@ -20,6 +20,7 @@ from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusion
|
||||
StableDiffusionProcessing
|
||||
from lib_controlnet.infotext import Infotext
|
||||
from modules_forge.forge_util import HWC3, numpy_to_pytorch
|
||||
from lib_controlnet.enums import HiResFixOption
|
||||
|
||||
import numpy as np
|
||||
import functools
|
||||
@@ -308,6 +309,11 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
logger.info('Batch wise input only support controlnet, control-lora, and t2i adapters!')
|
||||
break
|
||||
|
||||
if has_high_res_fix:
|
||||
hr_option = HiResFixOption.from_value(unit.hr_option)
|
||||
else:
|
||||
hr_option = HiResFixOption.BOTH
|
||||
|
||||
alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)]
|
||||
if preprocessor_output_is_image:
|
||||
params.control_cond = []
|
||||
@@ -315,7 +321,8 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
|
||||
for preprocessor_output in preprocessor_outputs:
|
||||
control_cond = crop_and_resize_image(preprocessor_output, resize_mode, h, w)
|
||||
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond))
|
||||
if hr_option.low_res_enabled:
|
||||
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond))
|
||||
params.control_cond.append(numpy_to_pytorch(control_cond).movedim(-1, 1))
|
||||
|
||||
params.control_cond = torch.cat(params.control_cond, dim=0)[alignment_indices].contiguous()
|
||||
@@ -323,7 +330,8 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
if has_high_res_fix:
|
||||
for preprocessor_output in preprocessor_outputs:
|
||||
control_cond_for_hr_fix = crop_and_resize_image(preprocessor_output, resize_mode, hr_y, hr_x)
|
||||
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond_for_hr_fix))
|
||||
if hr_option.high_res_enabled:
|
||||
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond_for_hr_fix))
|
||||
params.control_cond_for_hr_fix.append(numpy_to_pytorch(control_cond_for_hr_fix).movedim(-1, 1))
|
||||
params.control_cond_for_hr_fix = torch.cat(params.control_cond_for_hr_fix, dim=0)[alignment_indices].contiguous()
|
||||
else:
|
||||
@@ -340,13 +348,15 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
for input_mask in control_masks:
|
||||
fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill
|
||||
control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border)
|
||||
p.extra_result_images.append(control_mask)
|
||||
if hr_option.low_res_enabled:
|
||||
p.extra_result_images.append(control_mask)
|
||||
control_mask = numpy_to_pytorch(control_mask).movedim(-1, 1)[:, :1]
|
||||
params.control_mask.append(control_mask)
|
||||
|
||||
if has_high_res_fix:
|
||||
control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border)
|
||||
p.extra_result_images.append(control_mask_for_hr_fix)
|
||||
if hr_option.high_res_enabled:
|
||||
p.extra_result_images.append(control_mask_for_hr_fix)
|
||||
control_mask_for_hr_fix = numpy_to_pytorch(control_mask_for_hr_fix).movedim(-1, 1)[:, :1]
|
||||
params.control_mask_for_hr_fix.append(control_mask_for_hr_fix)
|
||||
|
||||
@@ -382,6 +392,24 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
|
||||
is_hr_pass = getattr(p, 'is_hr_pass', False)
|
||||
|
||||
has_high_res_fix = (
|
||||
isinstance(p, StableDiffusionProcessingTxt2Img)
|
||||
and getattr(p, 'enable_hr', False)
|
||||
)
|
||||
|
||||
if has_high_res_fix:
|
||||
hr_option = HiResFixOption.from_value(unit.hr_option)
|
||||
else:
|
||||
hr_option = HiResFixOption.BOTH
|
||||
|
||||
if has_high_res_fix and is_hr_pass and (not hr_option.high_res_enabled):
|
||||
logger.info(f"ControlNet Skipped High-res pass.")
|
||||
return
|
||||
|
||||
if has_high_res_fix and (not is_hr_pass) and (not hr_option.low_res_enabled):
|
||||
logger.info(f"ControlNet Skipped Low-res pass.")
|
||||
return
|
||||
|
||||
if is_hr_pass:
|
||||
cond = params.control_cond_for_hr_fix
|
||||
mask = params.control_mask_for_hr_fix
|
||||
|
||||
Reference in New Issue
Block a user