diff --git a/README.md b/README.md index c2cb3169..f8f72149 100644 --- a/README.md +++ b/README.md @@ -666,7 +666,7 @@ ControlNet and TiledVAE are integrated, and you should uninstall these two exten sd-webui-controlnet multidiffusion-upscaler-for-automatic1111 -Note that **animatediff** is under construction by [continue-revolution](https://github.com/continue-revolution) at [sd-forge-animatediff](https://github.com/continue-revolution/sd-forge-animatediff). (continue-revolution original words: "basic features (t2v, prompt travel, inf v2v) have been proven to work well, motion lora, cn v2v still under construction and may be finished in a week, and we can mention motion brush") +Note that **AnimateDiff** is under construction by [continue-revolution](https://github.com/continue-revolution) at [sd-webui-animatediff forge/master branch](https://github.com/continue-revolution/sd-webui-animatediff/tree/forge/master) and [sd-forge-animatediff](https://github.com/continue-revolution/sd-forge-animatediff) (they are in sync). (continue-revolution original words: "basic features (prompt travel, inf t2v) have been proven to work well, motion lora, cn v2v still under construction and may be finished in a week, and we can mention motion brush") Other extensions should work without problems, like: diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py index 796c00ca..ee8a2cf2 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py @@ -180,10 +180,11 @@ class ControlNetUiGroup(object): # Note: All gradio elements declared in `render` will be defined as member variable. # Update counter to trigger a force update of UiControlNetUnit. - # This is useful when a field with no event subscriber available changes. - # e.g. gr.Gallery, gr.State, etc. + # dummy_gradio_update_trigger is useful when a field with no event subscriber available changes. + # e.g. gr.Gallery, gr.State, etc. After an update to gr.State / gr.Gallery, please increment + # this counter to trigger a sync update of UiControlNetUnit. + self.dummy_gradio_update_trigger = None self.enabled = None - self.update_unit_counter = None self.upload_tab = None self.image = None self.generated_image_group = None @@ -251,7 +252,7 @@ class ControlNetUiGroup(object): Returns: None """ - self.update_unit_counter = gr.Number(value=0, visible=False) + self.dummy_gradio_update_trigger = gr.Number(value=0, visible=False) self.openpose_editor = OpenposeEditor() with gr.Group(visible=not self.is_img2img) as self.image_upload_panel: @@ -325,23 +326,44 @@ class ControlNetUiGroup(object): ) with gr.Tab(label="Batch Folder") as self.batch_tab: - self.batch_image_dir = gr.Textbox( - label="Input Directory", - placeholder="Input directory path to the control images.", - elem_id=f"{elem_id_tabname}_{tabname}_batch_image_dir", - ) + with gr.Row(): + self.batch_image_dir = gr.Textbox( + label="Input Directory", + placeholder="Input directory path to the control images.", + elem_id=f"{elem_id_tabname}_{tabname}_batch_image_dir", + ) + self.batch_mask_dir = gr.Textbox( + label="Mask Directory", + placeholder="Mask directory path to the control images.", + elem_id=f"{elem_id_tabname}_{tabname}_batch_mask_dir", + visible=False, + ) with gr.Tab(label="Batch Upload") as self.merge_tab: - self.batch_input_gallery = gr.Gallery( - columns=[4], rows=[2], object_fit="contain", height="auto" - ) with gr.Row(): - self.merge_upload_button = gr.UploadButton( - "Upload Images", - file_types=["image"], - file_count="multiple", - ) - self.merge_clear_button = gr.Button("Clear Images") + with gr.Column(): + self.batch_input_gallery = gr.Gallery( + columns=[4], rows=[2], object_fit="contain", height="auto", label="Images" + ) + with gr.Row(): + self.merge_upload_button = gr.UploadButton( + "Upload Images", + file_types=["image"], + file_count="multiple", + ) + self.merge_clear_button = gr.Button("Clear Images") + with gr.Group(visible=False, elem_classes=["cnet-mask-gallery-group"]) as self.batch_mask_gallery_group: + with gr.Column(): + self.batch_mask_gallery = gr.Gallery( + columns=[4], rows=[2], object_fit="contain", height="auto", label="Masks" + ) + with gr.Row(): + self.mask_merge_upload_button = gr.UploadButton( + "Upload Masks", + file_types=["image"], + file_count="multiple", + ) + self.mask_merge_clear_button = gr.Button("Clear Masks") if self.photopea: self.photopea.attach_photopea_output(self.generated_image) @@ -585,7 +607,9 @@ class ControlNetUiGroup(object): self.input_mode, self.use_preview_as_input, self.batch_image_dir, + self.batch_mask_dir, self.batch_input_gallery, + self.batch_mask_gallery, self.generated_image, self.mask_image, self.enabled, @@ -604,7 +628,7 @@ class ControlNetUiGroup(object): ) unit = gr.State(self.default_unit) - for comp in unit_args + (self.update_unit_counter,): + for comp in unit_args + (self.dummy_gradio_update_trigger,): event_subscribers = [] if hasattr(comp, "edit"): event_subscribers.append(comp.edit) @@ -961,16 +985,19 @@ class ControlNetUiGroup(object): def on_checkbox_click(checked: bool, canvas_height: int, canvas_width: int): if not checked: # Clear mask_image if unchecked. - return gr.update(visible=False), gr.update(value=None) + return gr.update(visible=False), gr.update(value=None), gr.update(value=None, visible=False), \ + gr.update(visible=False), gr.update(value=None) else: # Init an empty canvas the same size as the generation target. empty_canvas = np.zeros(shape=(canvas_height, canvas_width, 3), dtype=np.uint8) - return gr.update(visible=True), gr.update(value=empty_canvas) + return gr.update(visible=True), gr.update(value=empty_canvas), gr.update(visible=True), \ + gr.update(visible=True), gr.update() self.mask_upload.change( fn=on_checkbox_click, inputs=[self.mask_upload, self.height_slider, self.width_slider], - outputs=[self.mask_image_group, self.mask_image], + outputs=[self.mask_image_group, self.mask_image, self.batch_mask_dir, + self.batch_mask_gallery_group, self.batch_mask_gallery], show_progress=False, ) @@ -1061,8 +1088,13 @@ class ControlNetUiGroup(object): outputs=[self.batch_input_gallery], ).then( fn=lambda x: gr.update(value=x + 1), - inputs=[self.update_unit_counter], - outputs=[self.update_unit_counter], + inputs=[self.dummy_gradio_update_trigger], + outputs=[self.dummy_gradio_update_trigger], + ) + self.mask_merge_clear_button.click( + fn=lambda: [], + inputs=[], + outputs=[self.batch_mask_gallery], ) def upload_file(files, current_files): @@ -1077,8 +1109,14 @@ class ControlNetUiGroup(object): queue=False, ).then( fn=lambda x: gr.update(value=x + 1), - inputs=[self.update_unit_counter], - outputs=[self.update_unit_counter], + inputs=[self.dummy_gradio_update_trigger], + outputs=[self.dummy_gradio_update_trigger], + ) + self.mask_merge_upload_button.upload( + upload_file, + inputs=[self.mask_merge_upload_button, self.batch_mask_gallery], + outputs=[self.batch_mask_gallery], + queue=False, ) return @@ -1105,7 +1143,7 @@ class ControlNetUiGroup(object): self.type_filter, *[ getattr(self, key) - for key in vars(external_code.ControlNetUnit()).keys() + for key in external_code.ControlNetUnit.infotext_fields() ], ) if self.is_img2img: diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/preset.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/preset.py index 831bc93e..064e4ca6 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/preset.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/preset.py @@ -117,7 +117,7 @@ class ControlNetPresetUI(object): gr.update(visible=False), *( (gr.skip(),) - * (len(vars(external_code.ControlNetUnit()).keys()) + 1) + * (len(external_code.ControlNetUnit.infotext_fields()) + 1) ), ) @@ -140,7 +140,7 @@ class ControlNetPresetUI(object): gr.update(visible=False), *( (gr.skip(),) - * (len(vars(external_code.ControlNetUnit()).keys()) + 1) + * (len(external_code.ControlNetUnit.infotext_fields()) + 1) ), ) @@ -166,7 +166,8 @@ class ControlNetPresetUI(object): gr.update(value=new_control_type), *[ gr.update(value=value) if value is not None else gr.update() - for value in vars(unit).values() + for field in external_code.ControlNetUnit.infotext_fields() + for value in (getattr(unit, field),) ], ) diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py index 03dd7dda..df572028 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py @@ -151,7 +151,9 @@ class UiControlNetUnit: input_mode: InputMode = InputMode.SIMPLE use_preview_as_input: bool = False, batch_image_dir: str = '', + batch_mask_dir: str = '', batch_input_gallery: list = [], + batch_mask_gallery: list = [], generated_image: Optional[np.ndarray] = None, mask_image: Optional[np.ndarray] = None, enabled: bool = True @@ -168,6 +170,27 @@ class UiControlNetUnit: pixel_perfect: bool = False control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED + @staticmethod + def infotext_fields(): + """Fields that should be included in infotext. + You should define a Gradio element with exact same name in ControlNetUiGroup + as well, so that infotext can wire the value to correct field when pasting + infotext. + """ + return ( + "module", + "model", + "weight", + "resize_mode", + "processor_res", + "threshold_a", + "threshold_b", + "guidance_start", + "guidance_end", + "pixel_perfect", + "control_mode", + ) + # Backward Compatible ControlNetUnit = UiControlNetUnit diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/infotext.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/infotext.py index 9e61cd6e..78c1daaf 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/infotext.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/infotext.py @@ -29,21 +29,10 @@ def parse_value(value: str) -> Union[str, float, int, bool]: def serialize_unit(unit: external_code.ControlNetUnit) -> str: - excluded_fields = ( - "image", - "enabled", - "input_mode", - "use_preview_as_input", - "generated_image", - "mask_image", - "batch_input_gallery", - "batch_image_dir" - ) - log_value = { field_to_displaytext(field): getattr(unit, field) - for field in vars(external_code.ControlNetUnit()).keys() - if field not in excluded_fields and getattr(unit, field) != -1 + for field in external_code.ControlNetUnit.infotext_fields() + if getattr(unit, field) != -1 # Note: exclude hidden slider values. } if not all("," not in str(v) and ":" not in str(v) for v in log_value.values()): @@ -83,12 +72,8 @@ class Infotext(object): iocomponents. """ unit_prefix = Infotext.unit_prefix(unit_index) - for field in vars(external_code.ControlNetUnit()).keys(): - # Exclude image for infotext. - if field == "image": - continue - - # Every field in ControlNetUnit should have a cooresponding + for field in external_code.ControlNetUnit.infotext_fields(): + # Every field in ControlNetUnit should have a corresponding # IOComponent in ControlNetUiGroup. io_component = getattr(uigroup, field) component_locator = f"{unit_prefix} {field}" diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py index 13281aa5..a4a59726 100644 --- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py +++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py @@ -145,11 +145,17 @@ class ControlNetForForgeOfficial(scripts.Script): if unit.input_mode == external_code.InputMode.MERGE: image_list = [] - for item in unit.batch_input_gallery: + for idx, item in enumerate(unit.batch_input_gallery): img_path = item['name'] logger.info(f'Try to read image: {img_path}') img = np.ascontiguousarray(cv2.imread(img_path)[:, :, ::-1]).copy() mask = None + if len(unit.batch_mask_gallery) > 0: + if len(unit.batch_mask_gallery) >= len(unit.batch_input_gallery): + mask_path = unit.batch_mask_gallery[idx]['name'] + else: + mask_path = unit.batch_mask_gallery[0]['name'] + mask = np.ascontiguousarray(cv2.imread(mask_path)[:, :, ::-1]).copy() if img is not None: image_list.append([img, mask]) return image_list, resize_mode @@ -157,12 +163,19 @@ class ControlNetForForgeOfficial(scripts.Script): if unit.input_mode == external_code.InputMode.BATCH: image_list = [] image_extensions = ['.jpg', '.jpeg', '.png', '.bmp'] - for filename in os.listdir(unit.batch_image_dir): + for idx, filename in enumerate(os.listdir(unit.batch_image_dir)): if any(filename.lower().endswith(ext) for ext in image_extensions): img_path = os.path.join(unit.batch_image_dir, filename) logger.info(f'Try to read image: {img_path}') img = np.ascontiguousarray(cv2.imread(img_path)[:, :, ::-1]).copy() mask = None + if len(unit.batch_mask_dir) > 0: + if len(unit.batch_mask_dir) >= len(unit.batch_image_dir): + mask_path = unit.batch_mask_dir[idx] + else: + mask_path = unit.batch_mask_dir[0] + mask_path = os.path.join(unit.batch_mask_dir, mask_path) + mask = np.ascontiguousarray(cv2.imread(mask_path)[:, :, ::-1]).copy() if img is not None: image_list.append([img, mask]) return image_list, resize_mode @@ -252,11 +265,15 @@ class ControlNetForForgeOfficial(scripts.Script): input_list, resize_mode = self.get_input_data(p, unit, preprocessor) preprocessor_outputs = [] + control_masks = [] preprocessor_output_is_image = False - input_image, input_mask = input_list[0] preprocessor_output = None - for input_image, input_mask in input_list: + def optional_tqdm(iterable, use_tqdm): + from tqdm import tqdm + return tqdm(iterable) if use_tqdm else iterable + + for input_image, input_mask in optional_tqdm(input_list, len(input_list) > 1): # p.extra_result_images.append(input_image) if unit.pixel_perfect: @@ -284,12 +301,15 @@ class ControlNetForForgeOfficial(scripts.Script): preprocessor_output_is_image = judge_image_type(preprocessor_output) + if input_mask is not None: + control_masks.append(input_mask) + if len(input_list) > 1 and not preprocessor_output_is_image: logger.info('Batch wise input only support controlnet, control-lora, and t2i adapters!') break + alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)] if preprocessor_output_is_image: - alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)] params.control_cond = [] params.control_cond_for_hr_fix = [] @@ -313,16 +333,26 @@ class ControlNetForForgeOfficial(scripts.Script): params.control_cond_for_hr_fix = preprocessor_output p.extra_result_images.append(input_image) - if input_mask is not None: - fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill - params.control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border) - p.extra_result_images.append(params.control_mask) - params.control_mask = numpy_to_pytorch(params.control_mask).movedim(-1, 1)[:, :1] + if len(control_masks) > 0: + params.control_mask = [] + params.control_mask_for_hr_fix = [] + for input_mask in control_masks: + fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill + control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border) + p.extra_result_images.append(control_mask) + control_mask = numpy_to_pytorch(control_mask).movedim(-1, 1)[:, :1] + params.control_mask.append(control_mask) + + if has_high_res_fix: + control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border) + p.extra_result_images.append(control_mask_for_hr_fix) + control_mask_for_hr_fix = numpy_to_pytorch(control_mask_for_hr_fix).movedim(-1, 1)[:, :1] + params.control_mask_for_hr_fix.append(control_mask_for_hr_fix) + + params.control_mask = torch.cat(params.control_mask, dim=0)[alignment_indices].contiguous() if has_high_res_fix: - params.control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border) - p.extra_result_images.append(params.control_mask_for_hr_fix) - params.control_mask_for_hr_fix = numpy_to_pytorch(params.control_mask_for_hr_fix).movedim(-1, 1)[:, :1] + params.control_mask_for_hr_fix = torch.cat(params.control_mask_for_hr_fix, dim=0)[alignment_indices].contiguous() else: params.control_mask_for_hr_fix = params.control_mask diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py index 9ddd3c6b..9414f8e7 100644 --- a/ldm_patched/modules/controlnet.py +++ b/ldm_patched/modules/controlnet.py @@ -87,6 +87,8 @@ def compute_controlnet_weighting(control, cnet): final_weight = final_weight * sigma_weight * frame_weight if isinstance(advanced_mask_weighting, torch.Tensor): + if control_signal.shape[0] == 2 * advanced_mask_weighting.shape[0]: + advanced_mask_weighting = advanced_mask_weighting.repeat(2, 1, 1, 1) control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting.to(control_signal), size=(H, W), mode='bilinear') control[k][i] = control_signal * final_weight[:, None, None, None]