Fix inpaint mask in API (#188)

* Fix inpaint mask in API

* Add more tests

* Add tests
This commit is contained in:
Chenlei Hu
2024-02-11 16:49:21 +00:00
committed by GitHub
parent e11753ff84
commit 8316773caa
3 changed files with 102 additions and 5 deletions

View File

@@ -188,8 +188,8 @@ class ControlNetUnit:
# ====== End of UI only fields ======
# Following fields are used in both the API and the UI.
# Holds the mask image as a NumPy array; defaults to None.
mask_image: Optional[np.ndarray] = None
# Holds the mask image; defaults to None.
mask_image: Optional[GradioImageMaskPair] = None
# Specifies how this unit should be applied in each pass of high-resolution fix.
# Ignored if high-resolution fix is not enabled.
hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH
@@ -262,7 +262,19 @@ class ControlNetUnit:
"mask": np.zeros_like(img),
}
if isinstance(unit.mask_image, str):
unit.mask_image = np.array(api.decode_base64_to_image(unit.mask_image)).astype('uint8')
mask = np.array(api.decode_base64_to_image(unit.mask_image)).astype('uint8')
if unit.image is not None:
# Attach mask on image if ControlNet has input image.
assert isinstance(unit.image, dict)
unit.image["mask"] = mask
unit.mask_image = None
else:
# Otherwise, wire to standalone mask.
# This happens in img2img when using A1111 img2img input.
unit.mask_image = {
"image": mask,
"mask": np.zeros_like(mask),
}
return unit

View File

@@ -3,6 +3,9 @@ import pytest
from .template import (
APITestTemplate,
girl_img,
mask_img,
disable_in_cq,
get_model,
)
@@ -83,3 +86,86 @@ def test_save_map(gen_type, save_map):
unit_overrides={"save_detected_map": save_map},
input_image=girl_img,
).exec(expected_output_num=2 if save_map else 1)
@disable_in_cq
def test_masked_controlnet_txt2img():
assert APITestTemplate(
f"test_masked_controlnet_txt2img",
"txt2img",
payload_overrides={},
unit_overrides={
"image": girl_img,
"mask_image": mask_img,
},
).exec()
@disable_in_cq
def test_masked_controlnet_img2img():
assert APITestTemplate(
f"test_masked_controlnet_img2img",
"img2img",
payload_overrides={
"init_images": [girl_img],
},
# Note: Currently you must give ControlNet unit input image to specify
# mask.
# TODO: Fix this for img2img.
unit_overrides={
"image": girl_img,
"mask_image": mask_img,
},
).exec()
@disable_in_cq
def test_txt2img_inpaint():
assert APITestTemplate(
"txt2img_inpaint",
"txt2img",
payload_overrides={},
unit_overrides={
"image": girl_img,
"mask_image": mask_img,
"model": get_model("v11p_sd15_inpaint"),
"module": "inpaint_only",
},
).exec()
@disable_in_cq
def test_img2img_inpaint():
assert APITestTemplate(
"img2img_inpaint",
"img2img",
payload_overrides={
"init_images": [girl_img],
"mask": mask_img,
},
unit_overrides={
"model": get_model("v11p_sd15_inpaint"),
"module": "inpaint_only",
},
).exec()
# Currently failing.
# TODO Fix lama outpaint.
@disable_in_cq
def test_lama_outpaint():
assert APITestTemplate(
"txt2img_lama_outpaint",
"txt2img",
payload_overrides={
"width": 768,
"height": 768,
},
# Outpaint should not need a mask.
unit_overrides={
"image": girl_img,
"model": get_model("v11p_sd15_inpaint"),
"module": "inpaint_only+lama",
"resize_mode": "Resize and Fill", # OUTER_FIT
},
).exec()

View File

@@ -252,14 +252,13 @@ default_unit = {
"enabled": True,
"guidance_end": 1,
"guidance_start": 0,
"low_vram": False,
"pixel_perfect": True,
"processor_res": 512,
"resize_mode": 1,
"threshold_a": 64,
"threshold_b": 64,
"weight": 1,
"module": "None",
"module": "canny",
"model": get_model("sd15_canny"),
}