Add back ControlNet HR option (#90)

* Add back ControlNet HR option

* nits

* protect kernel

* add enum

* fix

* fix

* Update controlnet.py

* restore controlnet.py

* hint ui

* Update controlnet.py

* fix

* Update controlnet.py

* Backend: better controlnet mask batch broadcasting

* Update README.md

* fix inpaint batch dim align #94

* fix sigmas device in rare cases #71

* rework sigma device mapping

* Add hr_option to infotext

---------

Co-authored-by: lllyasviel <lyuminzhang@outlook.com>
This commit is contained in:
Chenlei Hu
2024-02-07 16:09:52 +00:00
committed by GitHub
parent 257ac2653a
commit e1faf8327b
4 changed files with 79 additions and 38 deletions

View File

@@ -20,6 +20,7 @@ from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusion
StableDiffusionProcessing
from lib_controlnet.infotext import Infotext
from modules_forge.forge_util import HWC3, numpy_to_pytorch
from lib_controlnet.enums import HiResFixOption
import numpy as np
import functools
@@ -308,6 +309,11 @@ class ControlNetForForgeOfficial(scripts.Script):
logger.info('Batch wise input only support controlnet, control-lora, and t2i adapters!')
break
if has_high_res_fix:
hr_option = HiResFixOption.from_value(unit.hr_option)
else:
hr_option = HiResFixOption.BOTH
alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)]
if preprocessor_output_is_image:
params.control_cond = []
@@ -315,7 +321,8 @@ class ControlNetForForgeOfficial(scripts.Script):
for preprocessor_output in preprocessor_outputs:
control_cond = crop_and_resize_image(preprocessor_output, resize_mode, h, w)
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond))
if hr_option.low_res_enabled:
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond))
params.control_cond.append(numpy_to_pytorch(control_cond).movedim(-1, 1))
params.control_cond = torch.cat(params.control_cond, dim=0)[alignment_indices].contiguous()
@@ -323,7 +330,8 @@ class ControlNetForForgeOfficial(scripts.Script):
if has_high_res_fix:
for preprocessor_output in preprocessor_outputs:
control_cond_for_hr_fix = crop_and_resize_image(preprocessor_output, resize_mode, hr_y, hr_x)
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond_for_hr_fix))
if hr_option.high_res_enabled:
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond_for_hr_fix))
params.control_cond_for_hr_fix.append(numpy_to_pytorch(control_cond_for_hr_fix).movedim(-1, 1))
params.control_cond_for_hr_fix = torch.cat(params.control_cond_for_hr_fix, dim=0)[alignment_indices].contiguous()
else:
@@ -340,13 +348,15 @@ class ControlNetForForgeOfficial(scripts.Script):
for input_mask in control_masks:
fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill
control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border)
p.extra_result_images.append(control_mask)
if hr_option.low_res_enabled:
p.extra_result_images.append(control_mask)
control_mask = numpy_to_pytorch(control_mask).movedim(-1, 1)[:, :1]
params.control_mask.append(control_mask)
if has_high_res_fix:
control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border)
p.extra_result_images.append(control_mask_for_hr_fix)
if hr_option.high_res_enabled:
p.extra_result_images.append(control_mask_for_hr_fix)
control_mask_for_hr_fix = numpy_to_pytorch(control_mask_for_hr_fix).movedim(-1, 1)[:, :1]
params.control_mask_for_hr_fix.append(control_mask_for_hr_fix)
@@ -382,6 +392,24 @@ class ControlNetForForgeOfficial(scripts.Script):
is_hr_pass = getattr(p, 'is_hr_pass', False)
has_high_res_fix = (
isinstance(p, StableDiffusionProcessingTxt2Img)
and getattr(p, 'enable_hr', False)
)
if has_high_res_fix:
hr_option = HiResFixOption.from_value(unit.hr_option)
else:
hr_option = HiResFixOption.BOTH
if has_high_res_fix and is_hr_pass and (not hr_option.high_res_enabled):
logger.info(f"ControlNet Skipped High-res pass.")
return
if has_high_res_fix and (not is_hr_pass) and (not hr_option.low_res_enabled):
logger.info(f"ControlNet Skipped Low-res pass.")
return
if is_hr_pass:
cond = params.control_cond_for_hr_fix
mask = params.control_mask_for_hr_fix