From d60eccb8ba8988c40a7b48e3213041d9a019453a Mon Sep 17 00:00:00 2001 From: continue revolution Date: Mon, 5 Feb 2024 23:01:57 -0600 Subject: [PATCH] mask batch, not working --- .../controlnet_ui/controlnet_ui_group.py | 71 ++++++++++++++----- .../lib_controlnet/external_code.py | 2 + .../sd_forge_controlnet/scripts/controlnet.py | 56 +++++++++++---- ldm_patched/modules/controlnet.py | 2 +- 4 files changed, 100 insertions(+), 31 deletions(-) diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py index 796c00ca..c59e519b 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py @@ -325,23 +325,44 @@ class ControlNetUiGroup(object): ) with gr.Tab(label="Batch Folder") as self.batch_tab: - self.batch_image_dir = gr.Textbox( - label="Input Directory", - placeholder="Input directory path to the control images.", - elem_id=f"{elem_id_tabname}_{tabname}_batch_image_dir", - ) + with gr.Row(): + self.batch_image_dir = gr.Textbox( + label="Input Directory", + placeholder="Input directory path to the control images.", + elem_id=f"{elem_id_tabname}_{tabname}_batch_image_dir", + ) + self.batch_mask_dir = gr.Textbox( + label="Mask Directory", + placeholder="Mask directory path to the control images.", + elem_id=f"{elem_id_tabname}_{tabname}_batch_mask_dir", + visible=False, + ) with gr.Tab(label="Batch Upload") as self.merge_tab: - self.batch_input_gallery = gr.Gallery( - columns=[4], rows=[2], object_fit="contain", height="auto" - ) with gr.Row(): - self.merge_upload_button = gr.UploadButton( - "Upload Images", - file_types=["image"], - file_count="multiple", - ) - self.merge_clear_button = gr.Button("Clear Images") + with gr.Column(): + self.batch_input_gallery = gr.Gallery( + columns=[4], rows=[2], object_fit="contain", height="auto", label="Images" + ) + with gr.Row(): + self.merge_upload_button = gr.UploadButton( + "Upload Images", + file_types=["image"], + file_count="multiple", + ) + self.merge_clear_button = gr.Button("Clear Images") + with gr.Group(visible=False, elem_classes=["cnet-mask-gallery-group"]) as self.batch_mask_gallery_group: + with gr.Column(): + self.batch_mask_gallery = gr.Gallery( + columns=[4], rows=[2], object_fit="contain", height="auto", label="Masks" + ) + with gr.Row(): + self.mask_merge_upload_button = gr.UploadButton( + "Upload Masks", + file_types=["image"], + file_count="multiple", + ) + self.mask_merge_clear_button = gr.Button("Clear Masks") if self.photopea: self.photopea.attach_photopea_output(self.generated_image) @@ -585,7 +606,9 @@ class ControlNetUiGroup(object): self.input_mode, self.use_preview_as_input, self.batch_image_dir, + self.batch_mask_dir, self.batch_input_gallery, + self.batch_mask_gallery, self.generated_image, self.mask_image, self.enabled, @@ -961,16 +984,19 @@ class ControlNetUiGroup(object): def on_checkbox_click(checked: bool, canvas_height: int, canvas_width: int): if not checked: # Clear mask_image if unchecked. - return gr.update(visible=False), gr.update(value=None) + return gr.update(visible=False), gr.update(value=None), gr.update(value=None, visible=False), \ + gr.update(visible=False), gr.update(value=None) else: # Init an empty canvas the same size as the generation target. empty_canvas = np.zeros(shape=(canvas_height, canvas_width, 3), dtype=np.uint8) - return gr.update(visible=True), gr.update(value=empty_canvas) + return gr.update(visible=True), gr.update(value=empty_canvas), gr.update(visible=True), \ + gr.update(visible=True), gr.update() self.mask_upload.change( fn=on_checkbox_click, inputs=[self.mask_upload, self.height_slider, self.width_slider], - outputs=[self.mask_image_group, self.mask_image], + outputs=[self.mask_image_group, self.mask_image, self.batch_mask_dir, + self.batch_mask_gallery_group, self.batch_mask_gallery], show_progress=False, ) @@ -1064,6 +1090,11 @@ class ControlNetUiGroup(object): inputs=[self.update_unit_counter], outputs=[self.update_unit_counter], ) + self.mask_merge_clear_button.click( + fn=lambda: [], + inputs=[], + outputs=[self.batch_mask_gallery], + ) def upload_file(files, current_files): return {file_d["name"] for file_d in current_files} | { @@ -1080,6 +1111,12 @@ class ControlNetUiGroup(object): inputs=[self.update_unit_counter], outputs=[self.update_unit_counter], ) + self.mask_merge_upload_button.upload( + upload_file, + inputs=[self.mask_merge_upload_button, self.batch_mask_gallery], + outputs=[self.batch_mask_gallery], + queue=False, + ) return def register_core_callbacks(self): diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py index 03dd7dda..dfe3dcc3 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py @@ -151,7 +151,9 @@ class UiControlNetUnit: input_mode: InputMode = InputMode.SIMPLE use_preview_as_input: bool = False, batch_image_dir: str = '', + batch_mask_dir: str = '', batch_input_gallery: list = [], + batch_mask_gallery: list = [], generated_image: Optional[np.ndarray] = None, mask_image: Optional[np.ndarray] = None, enabled: bool = True diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py index 13281aa5..6ca2390c 100644 --- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py +++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py @@ -145,11 +145,17 @@ class ControlNetForForgeOfficial(scripts.Script): if unit.input_mode == external_code.InputMode.MERGE: image_list = [] - for item in unit.batch_input_gallery: + for idx, item in enumerate(unit.batch_input_gallery): img_path = item['name'] logger.info(f'Try to read image: {img_path}') img = np.ascontiguousarray(cv2.imread(img_path)[:, :, ::-1]).copy() mask = None + if len(unit.batch_mask_gallery) > 0: + if len(unit.batch_mask_gallery) >= len(unit.batch_input_gallery): + mask_path = unit.batch_mask_gallery[idx]['name'] + else: + mask_path = unit.batch_mask_gallery[0]['name'] + mask = np.ascontiguousarray(cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)) if img is not None: image_list.append([img, mask]) return image_list, resize_mode @@ -157,12 +163,19 @@ class ControlNetForForgeOfficial(scripts.Script): if unit.input_mode == external_code.InputMode.BATCH: image_list = [] image_extensions = ['.jpg', '.jpeg', '.png', '.bmp'] - for filename in os.listdir(unit.batch_image_dir): + for idx, filename in enumerate(os.listdir(unit.batch_image_dir)): if any(filename.lower().endswith(ext) for ext in image_extensions): img_path = os.path.join(unit.batch_image_dir, filename) logger.info(f'Try to read image: {img_path}') img = np.ascontiguousarray(cv2.imread(img_path)[:, :, ::-1]).copy() mask = None + if len(unit.batch_mask_dir) > 0: + if len(unit.batch_mask_dir) >= len(unit.batch_image_dir): + mask_path = unit.batch_mask_dir[idx] + else: + mask_path = unit.batch_mask_dir[0] + mask_path = os.path.join(unit.batch_mask_dir, mask_path) + mask = np.ascontiguousarray(cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)) if img is not None: image_list.append([img, mask]) return image_list, resize_mode @@ -252,11 +265,15 @@ class ControlNetForForgeOfficial(scripts.Script): input_list, resize_mode = self.get_input_data(p, unit, preprocessor) preprocessor_outputs = [] + control_masks = [] preprocessor_output_is_image = False - input_image, input_mask = input_list[0] preprocessor_output = None - for input_image, input_mask in input_list: + def optional_tqdm(iterable, use_tqdm): + from tqdm import tqdm + return tqdm(iterable) if use_tqdm else iterable + + for input_image, input_mask in optional_tqdm(input_list, len(input_list) > 1): # p.extra_result_images.append(input_image) if unit.pixel_perfect: @@ -284,12 +301,15 @@ class ControlNetForForgeOfficial(scripts.Script): preprocessor_output_is_image = judge_image_type(preprocessor_output) + if input_mask is not None: + control_masks.append(input_mask) + if len(input_list) > 1 and not preprocessor_output_is_image: logger.info('Batch wise input only support controlnet, control-lora, and t2i adapters!') break + alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)] if preprocessor_output_is_image: - alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)] params.control_cond = [] params.control_cond_for_hr_fix = [] @@ -313,16 +333,26 @@ class ControlNetForForgeOfficial(scripts.Script): params.control_cond_for_hr_fix = preprocessor_output p.extra_result_images.append(input_image) - if input_mask is not None: - fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill - params.control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border) - p.extra_result_images.append(params.control_mask) - params.control_mask = numpy_to_pytorch(params.control_mask).movedim(-1, 1)[:, :1] + if len(control_masks) > 0: + params.control_mask = [] + params.control_mask_for_hr_fix = [] + for input_mask in control_masks: + fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill + control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border) + p.extra_result_images.append(params.control_mask) + control_mask = numpy_to_pytorch(control_mask).movedim(-1, 1)[:, :1] + params.control_mask.append(control_mask) + + if has_high_res_fix: + control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border) + p.extra_result_images.append(control_mask_for_hr_fix) + control_mask_for_hr_fix = numpy_to_pytorch(control_mask_for_hr_fix).movedim(-1, 1)[:, :1] + params.control_mask_for_hr_fix.append(control_mask_for_hr_fix) + + params.control_mask = torch.cat(params.control_mask, dim=0)[alignment_indices].contiguous() if has_high_res_fix: - params.control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border) - p.extra_result_images.append(params.control_mask_for_hr_fix) - params.control_mask_for_hr_fix = numpy_to_pytorch(params.control_mask_for_hr_fix).movedim(-1, 1)[:, :1] + params.control_mask_for_hr_fix = torch.cat(params.control_mask_for_hr_fix, dim=0)[alignment_indices].contiguous() else: params.control_mask_for_hr_fix = params.control_mask diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py index 9ddd3c6b..1935655b 100644 --- a/ldm_patched/modules/controlnet.py +++ b/ldm_patched/modules/controlnet.py @@ -87,7 +87,7 @@ def compute_controlnet_weighting(control, cnet): final_weight = final_weight * sigma_weight * frame_weight if isinstance(advanced_mask_weighting, torch.Tensor): - control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting.to(control_signal), size=(H, W), mode='bilinear') + control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting.to(control_signal), size=(B, H, W), mode='bilinear') control[k][i] = control_signal * final_weight[:, None, None, None]