diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py index c48f238a..4954478a 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py @@ -188,8 +188,8 @@ class ControlNetUnit: # ====== End of UI only fields ====== # Following fields are used in both the API and the UI. - # Holds the mask image as a NumPy array; defaults to None. - mask_image: Optional[np.ndarray] = None + # Holds the mask image; defaults to None. + mask_image: Optional[GradioImageMaskPair] = None # Specifies how this unit should be applied in each pass of high-resolution fix. # Ignored if high-resolution fix is not enabled. hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH @@ -262,7 +262,19 @@ class ControlNetUnit: "mask": np.zeros_like(img), } if isinstance(unit.mask_image, str): - unit.mask_image = np.array(api.decode_base64_to_image(unit.mask_image)).astype('uint8') + mask = np.array(api.decode_base64_to_image(unit.mask_image)).astype('uint8') + if unit.image is not None: + # Attach mask on image if ControlNet has input image. + assert isinstance(unit.image, dict) + unit.image["mask"] = mask + unit.mask_image = None + else: + # Otherwise, wire to standalone mask. + # This happens in img2img when using A1111 img2img input. + unit.mask_image = { + "image": mask, + "mask": np.zeros_like(mask), + } return unit diff --git a/extensions-builtin/sd_forge_controlnet/tests/web_api/generation_test.py b/extensions-builtin/sd_forge_controlnet/tests/web_api/generation_test.py index 9656b266..433819d1 100644 --- a/extensions-builtin/sd_forge_controlnet/tests/web_api/generation_test.py +++ b/extensions-builtin/sd_forge_controlnet/tests/web_api/generation_test.py @@ -3,6 +3,9 @@ import pytest from .template import ( APITestTemplate, girl_img, + mask_img, + disable_in_cq, + get_model, ) @@ -83,3 +86,86 @@ def test_save_map(gen_type, save_map): unit_overrides={"save_detected_map": save_map}, input_image=girl_img, ).exec(expected_output_num=2 if save_map else 1) + + +@disable_in_cq +def test_masked_controlnet_txt2img(): + assert APITestTemplate( + f"test_masked_controlnet_txt2img", + "txt2img", + payload_overrides={}, + unit_overrides={ + "image": girl_img, + "mask_image": mask_img, + }, + ).exec() + + +@disable_in_cq +def test_masked_controlnet_img2img(): + assert APITestTemplate( + f"test_masked_controlnet_img2img", + "img2img", + payload_overrides={ + "init_images": [girl_img], + }, + # Note: Currently you must give ControlNet unit input image to specify + # mask. + # TODO: Fix this for img2img. + unit_overrides={ + "image": girl_img, + "mask_image": mask_img, + }, + ).exec() + + +@disable_in_cq +def test_txt2img_inpaint(): + assert APITestTemplate( + "txt2img_inpaint", + "txt2img", + payload_overrides={}, + unit_overrides={ + "image": girl_img, + "mask_image": mask_img, + "model": get_model("v11p_sd15_inpaint"), + "module": "inpaint_only", + }, + ).exec() + + +@disable_in_cq +def test_img2img_inpaint(): + assert APITestTemplate( + "img2img_inpaint", + "img2img", + payload_overrides={ + "init_images": [girl_img], + "mask": mask_img, + }, + unit_overrides={ + "model": get_model("v11p_sd15_inpaint"), + "module": "inpaint_only", + }, + ).exec() + + +# Currently failing. +# TODO Fix lama outpaint. +@disable_in_cq +def test_lama_outpaint(): + assert APITestTemplate( + "txt2img_lama_outpaint", + "txt2img", + payload_overrides={ + "width": 768, + "height": 768, + }, + # Outpaint should not need a mask. + unit_overrides={ + "image": girl_img, + "model": get_model("v11p_sd15_inpaint"), + "module": "inpaint_only+lama", + "resize_mode": "Resize and Fill", # OUTER_FIT + }, + ).exec() diff --git a/extensions-builtin/sd_forge_controlnet/tests/web_api/template.py b/extensions-builtin/sd_forge_controlnet/tests/web_api/template.py index 56d33a2e..5129e541 100644 --- a/extensions-builtin/sd_forge_controlnet/tests/web_api/template.py +++ b/extensions-builtin/sd_forge_controlnet/tests/web_api/template.py @@ -252,14 +252,13 @@ default_unit = { "enabled": True, "guidance_end": 1, "guidance_start": 0, - "low_vram": False, "pixel_perfect": True, "processor_res": 512, "resize_mode": 1, "threshold_a": 64, "threshold_b": 64, "weight": 1, - "module": "None", + "module": "canny", "model": get_model("sd15_canny"), }