From 264634624eb595bc052f41a8ff021295c1f7bc41 Mon Sep 17 00:00:00 2001 From: continue revolution Date: Mon, 5 Feb 2024 23:55:59 -0600 Subject: [PATCH] mask batch, working, infotext broken --- .../sd_forge_controlnet/lib_controlnet/infotext.py | 4 +++- .../sd_forge_controlnet/scripts/controlnet.py | 6 +++--- ldm_patched/modules/controlnet.py | 4 +++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/infotext.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/infotext.py index 9e61cd6e..207c442a 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/infotext.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/infotext.py @@ -37,7 +37,9 @@ def serialize_unit(unit: external_code.ControlNetUnit) -> str: "generated_image", "mask_image", "batch_input_gallery", - "batch_image_dir" + "batch_mask_gallery", + "batch_image_dir", + "batch_mask_dir", ) log_value = { diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py index 6ca2390c..a4a59726 100644 --- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py +++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py @@ -155,7 +155,7 @@ class ControlNetForForgeOfficial(scripts.Script): mask_path = unit.batch_mask_gallery[idx]['name'] else: mask_path = unit.batch_mask_gallery[0]['name'] - mask = np.ascontiguousarray(cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)) + mask = np.ascontiguousarray(cv2.imread(mask_path)[:, :, ::-1]).copy() if img is not None: image_list.append([img, mask]) return image_list, resize_mode @@ -175,7 +175,7 @@ class ControlNetForForgeOfficial(scripts.Script): else: mask_path = unit.batch_mask_dir[0] mask_path = os.path.join(unit.batch_mask_dir, mask_path) - mask = np.ascontiguousarray(cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)) + mask = np.ascontiguousarray(cv2.imread(mask_path)[:, :, ::-1]).copy() if img is not None: image_list.append([img, mask]) return image_list, resize_mode @@ -340,7 +340,7 @@ class ControlNetForForgeOfficial(scripts.Script): for input_mask in control_masks: fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border) - p.extra_result_images.append(params.control_mask) + p.extra_result_images.append(control_mask) control_mask = numpy_to_pytorch(control_mask).movedim(-1, 1)[:, :1] params.control_mask.append(control_mask) diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py index 1935655b..9414f8e7 100644 --- a/ldm_patched/modules/controlnet.py +++ b/ldm_patched/modules/controlnet.py @@ -87,7 +87,9 @@ def compute_controlnet_weighting(control, cnet): final_weight = final_weight * sigma_weight * frame_weight if isinstance(advanced_mask_weighting, torch.Tensor): - control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting.to(control_signal), size=(B, H, W), mode='bilinear') + if control_signal.shape[0] == 2 * advanced_mask_weighting.shape[0]: + advanced_mask_weighting = advanced_mask_weighting.repeat(2, 1, 1, 1) + control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting.to(control_signal), size=(H, W), mode='bilinear') control[k][i] = control_signal * final_weight[:, None, None, None]