mask batch, not working

This commit is contained in:
continue revolution
2024-02-05 23:01:57 -06:00
parent c5b51b35fb
commit d60eccb8ba
4 changed files with 100 additions and 31 deletions

View File

@@ -325,23 +325,44 @@ class ControlNetUiGroup(object):
)
with gr.Tab(label="Batch Folder") as self.batch_tab:
self.batch_image_dir = gr.Textbox(
label="Input Directory",
placeholder="Input directory path to the control images.",
elem_id=f"{elem_id_tabname}_{tabname}_batch_image_dir",
)
with gr.Row():
self.batch_image_dir = gr.Textbox(
label="Input Directory",
placeholder="Input directory path to the control images.",
elem_id=f"{elem_id_tabname}_{tabname}_batch_image_dir",
)
self.batch_mask_dir = gr.Textbox(
label="Mask Directory",
placeholder="Mask directory path to the control images.",
elem_id=f"{elem_id_tabname}_{tabname}_batch_mask_dir",
visible=False,
)
with gr.Tab(label="Batch Upload") as self.merge_tab:
self.batch_input_gallery = gr.Gallery(
columns=[4], rows=[2], object_fit="contain", height="auto"
)
with gr.Row():
self.merge_upload_button = gr.UploadButton(
"Upload Images",
file_types=["image"],
file_count="multiple",
)
self.merge_clear_button = gr.Button("Clear Images")
with gr.Column():
self.batch_input_gallery = gr.Gallery(
columns=[4], rows=[2], object_fit="contain", height="auto", label="Images"
)
with gr.Row():
self.merge_upload_button = gr.UploadButton(
"Upload Images",
file_types=["image"],
file_count="multiple",
)
self.merge_clear_button = gr.Button("Clear Images")
with gr.Group(visible=False, elem_classes=["cnet-mask-gallery-group"]) as self.batch_mask_gallery_group:
with gr.Column():
self.batch_mask_gallery = gr.Gallery(
columns=[4], rows=[2], object_fit="contain", height="auto", label="Masks"
)
with gr.Row():
self.mask_merge_upload_button = gr.UploadButton(
"Upload Masks",
file_types=["image"],
file_count="multiple",
)
self.mask_merge_clear_button = gr.Button("Clear Masks")
if self.photopea:
self.photopea.attach_photopea_output(self.generated_image)
@@ -585,7 +606,9 @@ class ControlNetUiGroup(object):
self.input_mode,
self.use_preview_as_input,
self.batch_image_dir,
self.batch_mask_dir,
self.batch_input_gallery,
self.batch_mask_gallery,
self.generated_image,
self.mask_image,
self.enabled,
@@ -961,16 +984,19 @@ class ControlNetUiGroup(object):
def on_checkbox_click(checked: bool, canvas_height: int, canvas_width: int):
if not checked:
# Clear mask_image if unchecked.
return gr.update(visible=False), gr.update(value=None)
return gr.update(visible=False), gr.update(value=None), gr.update(value=None, visible=False), \
gr.update(visible=False), gr.update(value=None)
else:
# Init an empty canvas the same size as the generation target.
empty_canvas = np.zeros(shape=(canvas_height, canvas_width, 3), dtype=np.uint8)
return gr.update(visible=True), gr.update(value=empty_canvas)
return gr.update(visible=True), gr.update(value=empty_canvas), gr.update(visible=True), \
gr.update(visible=True), gr.update()
self.mask_upload.change(
fn=on_checkbox_click,
inputs=[self.mask_upload, self.height_slider, self.width_slider],
outputs=[self.mask_image_group, self.mask_image],
outputs=[self.mask_image_group, self.mask_image, self.batch_mask_dir,
self.batch_mask_gallery_group, self.batch_mask_gallery],
show_progress=False,
)
@@ -1064,6 +1090,11 @@ class ControlNetUiGroup(object):
inputs=[self.update_unit_counter],
outputs=[self.update_unit_counter],
)
self.mask_merge_clear_button.click(
fn=lambda: [],
inputs=[],
outputs=[self.batch_mask_gallery],
)
def upload_file(files, current_files):
return {file_d["name"] for file_d in current_files} | {
@@ -1080,6 +1111,12 @@ class ControlNetUiGroup(object):
inputs=[self.update_unit_counter],
outputs=[self.update_unit_counter],
)
self.mask_merge_upload_button.upload(
upload_file,
inputs=[self.mask_merge_upload_button, self.batch_mask_gallery],
outputs=[self.batch_mask_gallery],
queue=False,
)
return
def register_core_callbacks(self):

View File

@@ -151,7 +151,9 @@ class UiControlNetUnit:
input_mode: InputMode = InputMode.SIMPLE
use_preview_as_input: bool = False,
batch_image_dir: str = '',
batch_mask_dir: str = '',
batch_input_gallery: list = [],
batch_mask_gallery: list = [],
generated_image: Optional[np.ndarray] = None,
mask_image: Optional[np.ndarray] = None,
enabled: bool = True

View File

@@ -145,11 +145,17 @@ class ControlNetForForgeOfficial(scripts.Script):
if unit.input_mode == external_code.InputMode.MERGE:
image_list = []
for item in unit.batch_input_gallery:
for idx, item in enumerate(unit.batch_input_gallery):
img_path = item['name']
logger.info(f'Try to read image: {img_path}')
img = np.ascontiguousarray(cv2.imread(img_path)[:, :, ::-1]).copy()
mask = None
if len(unit.batch_mask_gallery) > 0:
if len(unit.batch_mask_gallery) >= len(unit.batch_input_gallery):
mask_path = unit.batch_mask_gallery[idx]['name']
else:
mask_path = unit.batch_mask_gallery[0]['name']
mask = np.ascontiguousarray(cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE))
if img is not None:
image_list.append([img, mask])
return image_list, resize_mode
@@ -157,12 +163,19 @@ class ControlNetForForgeOfficial(scripts.Script):
if unit.input_mode == external_code.InputMode.BATCH:
image_list = []
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
for filename in os.listdir(unit.batch_image_dir):
for idx, filename in enumerate(os.listdir(unit.batch_image_dir)):
if any(filename.lower().endswith(ext) for ext in image_extensions):
img_path = os.path.join(unit.batch_image_dir, filename)
logger.info(f'Try to read image: {img_path}')
img = np.ascontiguousarray(cv2.imread(img_path)[:, :, ::-1]).copy()
mask = None
if len(unit.batch_mask_dir) > 0:
if len(unit.batch_mask_dir) >= len(unit.batch_image_dir):
mask_path = unit.batch_mask_dir[idx]
else:
mask_path = unit.batch_mask_dir[0]
mask_path = os.path.join(unit.batch_mask_dir, mask_path)
mask = np.ascontiguousarray(cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE))
if img is not None:
image_list.append([img, mask])
return image_list, resize_mode
@@ -252,11 +265,15 @@ class ControlNetForForgeOfficial(scripts.Script):
input_list, resize_mode = self.get_input_data(p, unit, preprocessor)
preprocessor_outputs = []
control_masks = []
preprocessor_output_is_image = False
input_image, input_mask = input_list[0]
preprocessor_output = None
for input_image, input_mask in input_list:
def optional_tqdm(iterable, use_tqdm):
from tqdm import tqdm
return tqdm(iterable) if use_tqdm else iterable
for input_image, input_mask in optional_tqdm(input_list, len(input_list) > 1):
# p.extra_result_images.append(input_image)
if unit.pixel_perfect:
@@ -284,12 +301,15 @@ class ControlNetForForgeOfficial(scripts.Script):
preprocessor_output_is_image = judge_image_type(preprocessor_output)
if input_mask is not None:
control_masks.append(input_mask)
if len(input_list) > 1 and not preprocessor_output_is_image:
logger.info('Batch wise input only support controlnet, control-lora, and t2i adapters!')
break
alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)]
if preprocessor_output_is_image:
alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)]
params.control_cond = []
params.control_cond_for_hr_fix = []
@@ -313,16 +333,26 @@ class ControlNetForForgeOfficial(scripts.Script):
params.control_cond_for_hr_fix = preprocessor_output
p.extra_result_images.append(input_image)
if input_mask is not None:
fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill
params.control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border)
p.extra_result_images.append(params.control_mask)
params.control_mask = numpy_to_pytorch(params.control_mask).movedim(-1, 1)[:, :1]
if len(control_masks) > 0:
params.control_mask = []
params.control_mask_for_hr_fix = []
for input_mask in control_masks:
fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill
control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border)
p.extra_result_images.append(params.control_mask)
control_mask = numpy_to_pytorch(control_mask).movedim(-1, 1)[:, :1]
params.control_mask.append(control_mask)
if has_high_res_fix:
control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border)
p.extra_result_images.append(control_mask_for_hr_fix)
control_mask_for_hr_fix = numpy_to_pytorch(control_mask_for_hr_fix).movedim(-1, 1)[:, :1]
params.control_mask_for_hr_fix.append(control_mask_for_hr_fix)
params.control_mask = torch.cat(params.control_mask, dim=0)[alignment_indices].contiguous()
if has_high_res_fix:
params.control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border)
p.extra_result_images.append(params.control_mask_for_hr_fix)
params.control_mask_for_hr_fix = numpy_to_pytorch(params.control_mask_for_hr_fix).movedim(-1, 1)[:, :1]
params.control_mask_for_hr_fix = torch.cat(params.control_mask_for_hr_fix, dim=0)[alignment_indices].contiguous()
else:
params.control_mask_for_hr_fix = params.control_mask

View File

@@ -87,7 +87,7 @@ def compute_controlnet_weighting(control, cnet):
final_weight = final_weight * sigma_weight * frame_weight
if isinstance(advanced_mask_weighting, torch.Tensor):
control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting.to(control_signal), size=(H, W), mode='bilinear')
control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting.to(control_signal), size=(B, H, W), mode='bilinear')
control[k][i] = control_signal * final_weight[:, None, None, None]