mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-02 14:27:27 +00:00
batch mask (#44)
* mask batch, not working
* mask batch, working, infotext broken
* try remove old codes
* set CUDA_VISIBLE_DEVICES with args
* Revert "try remove old codes"
This reverts commit 63c527c373.
* Update controlnet_ui_group.py
* readme
* 🐛 Fix infotext
---------
Co-authored-by: lllyasviel <lyuminzhang@outlook.com>
Co-authored-by: huchenlei <chenlei.hu@mail.utoronto.ca>
This commit is contained in:
@@ -666,7 +666,7 @@ ControlNet and TiledVAE are integrated, and you should uninstall these two exten
|
||||
sd-webui-controlnet
|
||||
multidiffusion-upscaler-for-automatic1111
|
||||
|
||||
Note that **animatediff** is under construction by [continue-revolution](https://github.com/continue-revolution) at [sd-forge-animatediff](https://github.com/continue-revolution/sd-forge-animatediff). (continue-revolution original words: "basic features (t2v, prompt travel, inf v2v) have been proven to work well, motion lora, cn v2v still under construction and may be finished in a week, and we can mention motion brush")
|
||||
Note that **AnimateDiff** is under construction by [continue-revolution](https://github.com/continue-revolution) at [sd-webui-animatediff forge/master branch](https://github.com/continue-revolution/sd-webui-animatediff/tree/forge/master) and [sd-forge-animatediff](https://github.com/continue-revolution/sd-forge-animatediff) (they are in sync). (continue-revolution original words: "basic features (prompt travel, inf t2v) have been proven to work well, motion lora, cn v2v still under construction and may be finished in a week, and we can mention motion brush")
|
||||
|
||||
Other extensions should work without problems, like:
|
||||
|
||||
|
||||
@@ -180,10 +180,11 @@ class ControlNetUiGroup(object):
|
||||
|
||||
# Note: All gradio elements declared in `render` will be defined as member variable.
|
||||
# Update counter to trigger a force update of UiControlNetUnit.
|
||||
# This is useful when a field with no event subscriber available changes.
|
||||
# e.g. gr.Gallery, gr.State, etc.
|
||||
# dummy_gradio_update_trigger is useful when a field with no event subscriber available changes.
|
||||
# e.g. gr.Gallery, gr.State, etc. After an update to gr.State / gr.Gallery, please increment
|
||||
# this counter to trigger a sync update of UiControlNetUnit.
|
||||
self.dummy_gradio_update_trigger = None
|
||||
self.enabled = None
|
||||
self.update_unit_counter = None
|
||||
self.upload_tab = None
|
||||
self.image = None
|
||||
self.generated_image_group = None
|
||||
@@ -251,7 +252,7 @@ class ControlNetUiGroup(object):
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.update_unit_counter = gr.Number(value=0, visible=False)
|
||||
self.dummy_gradio_update_trigger = gr.Number(value=0, visible=False)
|
||||
self.openpose_editor = OpenposeEditor()
|
||||
|
||||
with gr.Group(visible=not self.is_img2img) as self.image_upload_panel:
|
||||
@@ -325,23 +326,44 @@ class ControlNetUiGroup(object):
|
||||
)
|
||||
|
||||
with gr.Tab(label="Batch Folder") as self.batch_tab:
|
||||
self.batch_image_dir = gr.Textbox(
|
||||
label="Input Directory",
|
||||
placeholder="Input directory path to the control images.",
|
||||
elem_id=f"{elem_id_tabname}_{tabname}_batch_image_dir",
|
||||
)
|
||||
with gr.Row():
|
||||
self.batch_image_dir = gr.Textbox(
|
||||
label="Input Directory",
|
||||
placeholder="Input directory path to the control images.",
|
||||
elem_id=f"{elem_id_tabname}_{tabname}_batch_image_dir",
|
||||
)
|
||||
self.batch_mask_dir = gr.Textbox(
|
||||
label="Mask Directory",
|
||||
placeholder="Mask directory path to the control images.",
|
||||
elem_id=f"{elem_id_tabname}_{tabname}_batch_mask_dir",
|
||||
visible=False,
|
||||
)
|
||||
|
||||
with gr.Tab(label="Batch Upload") as self.merge_tab:
|
||||
self.batch_input_gallery = gr.Gallery(
|
||||
columns=[4], rows=[2], object_fit="contain", height="auto"
|
||||
)
|
||||
with gr.Row():
|
||||
self.merge_upload_button = gr.UploadButton(
|
||||
"Upload Images",
|
||||
file_types=["image"],
|
||||
file_count="multiple",
|
||||
)
|
||||
self.merge_clear_button = gr.Button("Clear Images")
|
||||
with gr.Column():
|
||||
self.batch_input_gallery = gr.Gallery(
|
||||
columns=[4], rows=[2], object_fit="contain", height="auto", label="Images"
|
||||
)
|
||||
with gr.Row():
|
||||
self.merge_upload_button = gr.UploadButton(
|
||||
"Upload Images",
|
||||
file_types=["image"],
|
||||
file_count="multiple",
|
||||
)
|
||||
self.merge_clear_button = gr.Button("Clear Images")
|
||||
with gr.Group(visible=False, elem_classes=["cnet-mask-gallery-group"]) as self.batch_mask_gallery_group:
|
||||
with gr.Column():
|
||||
self.batch_mask_gallery = gr.Gallery(
|
||||
columns=[4], rows=[2], object_fit="contain", height="auto", label="Masks"
|
||||
)
|
||||
with gr.Row():
|
||||
self.mask_merge_upload_button = gr.UploadButton(
|
||||
"Upload Masks",
|
||||
file_types=["image"],
|
||||
file_count="multiple",
|
||||
)
|
||||
self.mask_merge_clear_button = gr.Button("Clear Masks")
|
||||
|
||||
if self.photopea:
|
||||
self.photopea.attach_photopea_output(self.generated_image)
|
||||
@@ -585,7 +607,9 @@ class ControlNetUiGroup(object):
|
||||
self.input_mode,
|
||||
self.use_preview_as_input,
|
||||
self.batch_image_dir,
|
||||
self.batch_mask_dir,
|
||||
self.batch_input_gallery,
|
||||
self.batch_mask_gallery,
|
||||
self.generated_image,
|
||||
self.mask_image,
|
||||
self.enabled,
|
||||
@@ -604,7 +628,7 @@ class ControlNetUiGroup(object):
|
||||
)
|
||||
|
||||
unit = gr.State(self.default_unit)
|
||||
for comp in unit_args + (self.update_unit_counter,):
|
||||
for comp in unit_args + (self.dummy_gradio_update_trigger,):
|
||||
event_subscribers = []
|
||||
if hasattr(comp, "edit"):
|
||||
event_subscribers.append(comp.edit)
|
||||
@@ -961,16 +985,19 @@ class ControlNetUiGroup(object):
|
||||
def on_checkbox_click(checked: bool, canvas_height: int, canvas_width: int):
|
||||
if not checked:
|
||||
# Clear mask_image if unchecked.
|
||||
return gr.update(visible=False), gr.update(value=None)
|
||||
return gr.update(visible=False), gr.update(value=None), gr.update(value=None, visible=False), \
|
||||
gr.update(visible=False), gr.update(value=None)
|
||||
else:
|
||||
# Init an empty canvas the same size as the generation target.
|
||||
empty_canvas = np.zeros(shape=(canvas_height, canvas_width, 3), dtype=np.uint8)
|
||||
return gr.update(visible=True), gr.update(value=empty_canvas)
|
||||
return gr.update(visible=True), gr.update(value=empty_canvas), gr.update(visible=True), \
|
||||
gr.update(visible=True), gr.update()
|
||||
|
||||
self.mask_upload.change(
|
||||
fn=on_checkbox_click,
|
||||
inputs=[self.mask_upload, self.height_slider, self.width_slider],
|
||||
outputs=[self.mask_image_group, self.mask_image],
|
||||
outputs=[self.mask_image_group, self.mask_image, self.batch_mask_dir,
|
||||
self.batch_mask_gallery_group, self.batch_mask_gallery],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
@@ -1061,8 +1088,13 @@ class ControlNetUiGroup(object):
|
||||
outputs=[self.batch_input_gallery],
|
||||
).then(
|
||||
fn=lambda x: gr.update(value=x + 1),
|
||||
inputs=[self.update_unit_counter],
|
||||
outputs=[self.update_unit_counter],
|
||||
inputs=[self.dummy_gradio_update_trigger],
|
||||
outputs=[self.dummy_gradio_update_trigger],
|
||||
)
|
||||
self.mask_merge_clear_button.click(
|
||||
fn=lambda: [],
|
||||
inputs=[],
|
||||
outputs=[self.batch_mask_gallery],
|
||||
)
|
||||
|
||||
def upload_file(files, current_files):
|
||||
@@ -1077,8 +1109,14 @@ class ControlNetUiGroup(object):
|
||||
queue=False,
|
||||
).then(
|
||||
fn=lambda x: gr.update(value=x + 1),
|
||||
inputs=[self.update_unit_counter],
|
||||
outputs=[self.update_unit_counter],
|
||||
inputs=[self.dummy_gradio_update_trigger],
|
||||
outputs=[self.dummy_gradio_update_trigger],
|
||||
)
|
||||
self.mask_merge_upload_button.upload(
|
||||
upload_file,
|
||||
inputs=[self.mask_merge_upload_button, self.batch_mask_gallery],
|
||||
outputs=[self.batch_mask_gallery],
|
||||
queue=False,
|
||||
)
|
||||
return
|
||||
|
||||
@@ -1105,7 +1143,7 @@ class ControlNetUiGroup(object):
|
||||
self.type_filter,
|
||||
*[
|
||||
getattr(self, key)
|
||||
for key in vars(external_code.ControlNetUnit()).keys()
|
||||
for key in external_code.ControlNetUnit.infotext_fields()
|
||||
],
|
||||
)
|
||||
if self.is_img2img:
|
||||
|
||||
@@ -117,7 +117,7 @@ class ControlNetPresetUI(object):
|
||||
gr.update(visible=False),
|
||||
*(
|
||||
(gr.skip(),)
|
||||
* (len(vars(external_code.ControlNetUnit()).keys()) + 1)
|
||||
* (len(external_code.ControlNetUnit.infotext_fields()) + 1)
|
||||
),
|
||||
)
|
||||
|
||||
@@ -140,7 +140,7 @@ class ControlNetPresetUI(object):
|
||||
gr.update(visible=False),
|
||||
*(
|
||||
(gr.skip(),)
|
||||
* (len(vars(external_code.ControlNetUnit()).keys()) + 1)
|
||||
* (len(external_code.ControlNetUnit.infotext_fields()) + 1)
|
||||
),
|
||||
)
|
||||
|
||||
@@ -166,7 +166,8 @@ class ControlNetPresetUI(object):
|
||||
gr.update(value=new_control_type),
|
||||
*[
|
||||
gr.update(value=value) if value is not None else gr.update()
|
||||
for value in vars(unit).values()
|
||||
for field in external_code.ControlNetUnit.infotext_fields()
|
||||
for value in (getattr(unit, field),)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -151,7 +151,9 @@ class UiControlNetUnit:
|
||||
input_mode: InputMode = InputMode.SIMPLE
|
||||
use_preview_as_input: bool = False,
|
||||
batch_image_dir: str = '',
|
||||
batch_mask_dir: str = '',
|
||||
batch_input_gallery: list = [],
|
||||
batch_mask_gallery: list = [],
|
||||
generated_image: Optional[np.ndarray] = None,
|
||||
mask_image: Optional[np.ndarray] = None,
|
||||
enabled: bool = True
|
||||
@@ -168,6 +170,27 @@ class UiControlNetUnit:
|
||||
pixel_perfect: bool = False
|
||||
control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED
|
||||
|
||||
@staticmethod
|
||||
def infotext_fields():
|
||||
"""Fields that should be included in infotext.
|
||||
You should define a Gradio element with exact same name in ControlNetUiGroup
|
||||
as well, so that infotext can wire the value to correct field when pasting
|
||||
infotext.
|
||||
"""
|
||||
return (
|
||||
"module",
|
||||
"model",
|
||||
"weight",
|
||||
"resize_mode",
|
||||
"processor_res",
|
||||
"threshold_a",
|
||||
"threshold_b",
|
||||
"guidance_start",
|
||||
"guidance_end",
|
||||
"pixel_perfect",
|
||||
"control_mode",
|
||||
)
|
||||
|
||||
|
||||
# Backward Compatible
|
||||
ControlNetUnit = UiControlNetUnit
|
||||
|
||||
@@ -29,21 +29,10 @@ def parse_value(value: str) -> Union[str, float, int, bool]:
|
||||
|
||||
|
||||
def serialize_unit(unit: external_code.ControlNetUnit) -> str:
|
||||
excluded_fields = (
|
||||
"image",
|
||||
"enabled",
|
||||
"input_mode",
|
||||
"use_preview_as_input",
|
||||
"generated_image",
|
||||
"mask_image",
|
||||
"batch_input_gallery",
|
||||
"batch_image_dir"
|
||||
)
|
||||
|
||||
log_value = {
|
||||
field_to_displaytext(field): getattr(unit, field)
|
||||
for field in vars(external_code.ControlNetUnit()).keys()
|
||||
if field not in excluded_fields and getattr(unit, field) != -1
|
||||
for field in external_code.ControlNetUnit.infotext_fields()
|
||||
if getattr(unit, field) != -1
|
||||
# Note: exclude hidden slider values.
|
||||
}
|
||||
if not all("," not in str(v) and ":" not in str(v) for v in log_value.values()):
|
||||
@@ -83,12 +72,8 @@ class Infotext(object):
|
||||
iocomponents.
|
||||
"""
|
||||
unit_prefix = Infotext.unit_prefix(unit_index)
|
||||
for field in vars(external_code.ControlNetUnit()).keys():
|
||||
# Exclude image for infotext.
|
||||
if field == "image":
|
||||
continue
|
||||
|
||||
# Every field in ControlNetUnit should have a cooresponding
|
||||
for field in external_code.ControlNetUnit.infotext_fields():
|
||||
# Every field in ControlNetUnit should have a corresponding
|
||||
# IOComponent in ControlNetUiGroup.
|
||||
io_component = getattr(uigroup, field)
|
||||
component_locator = f"{unit_prefix} {field}"
|
||||
|
||||
@@ -145,11 +145,17 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
|
||||
if unit.input_mode == external_code.InputMode.MERGE:
|
||||
image_list = []
|
||||
for item in unit.batch_input_gallery:
|
||||
for idx, item in enumerate(unit.batch_input_gallery):
|
||||
img_path = item['name']
|
||||
logger.info(f'Try to read image: {img_path}')
|
||||
img = np.ascontiguousarray(cv2.imread(img_path)[:, :, ::-1]).copy()
|
||||
mask = None
|
||||
if len(unit.batch_mask_gallery) > 0:
|
||||
if len(unit.batch_mask_gallery) >= len(unit.batch_input_gallery):
|
||||
mask_path = unit.batch_mask_gallery[idx]['name']
|
||||
else:
|
||||
mask_path = unit.batch_mask_gallery[0]['name']
|
||||
mask = np.ascontiguousarray(cv2.imread(mask_path)[:, :, ::-1]).copy()
|
||||
if img is not None:
|
||||
image_list.append([img, mask])
|
||||
return image_list, resize_mode
|
||||
@@ -157,12 +163,19 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
if unit.input_mode == external_code.InputMode.BATCH:
|
||||
image_list = []
|
||||
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
|
||||
for filename in os.listdir(unit.batch_image_dir):
|
||||
for idx, filename in enumerate(os.listdir(unit.batch_image_dir)):
|
||||
if any(filename.lower().endswith(ext) for ext in image_extensions):
|
||||
img_path = os.path.join(unit.batch_image_dir, filename)
|
||||
logger.info(f'Try to read image: {img_path}')
|
||||
img = np.ascontiguousarray(cv2.imread(img_path)[:, :, ::-1]).copy()
|
||||
mask = None
|
||||
if len(unit.batch_mask_dir) > 0:
|
||||
if len(unit.batch_mask_dir) >= len(unit.batch_image_dir):
|
||||
mask_path = unit.batch_mask_dir[idx]
|
||||
else:
|
||||
mask_path = unit.batch_mask_dir[0]
|
||||
mask_path = os.path.join(unit.batch_mask_dir, mask_path)
|
||||
mask = np.ascontiguousarray(cv2.imread(mask_path)[:, :, ::-1]).copy()
|
||||
if img is not None:
|
||||
image_list.append([img, mask])
|
||||
return image_list, resize_mode
|
||||
@@ -252,11 +265,15 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
|
||||
input_list, resize_mode = self.get_input_data(p, unit, preprocessor)
|
||||
preprocessor_outputs = []
|
||||
control_masks = []
|
||||
preprocessor_output_is_image = False
|
||||
input_image, input_mask = input_list[0]
|
||||
preprocessor_output = None
|
||||
|
||||
for input_image, input_mask in input_list:
|
||||
def optional_tqdm(iterable, use_tqdm):
|
||||
from tqdm import tqdm
|
||||
return tqdm(iterable) if use_tqdm else iterable
|
||||
|
||||
for input_image, input_mask in optional_tqdm(input_list, len(input_list) > 1):
|
||||
# p.extra_result_images.append(input_image)
|
||||
|
||||
if unit.pixel_perfect:
|
||||
@@ -284,12 +301,15 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
|
||||
preprocessor_output_is_image = judge_image_type(preprocessor_output)
|
||||
|
||||
if input_mask is not None:
|
||||
control_masks.append(input_mask)
|
||||
|
||||
if len(input_list) > 1 and not preprocessor_output_is_image:
|
||||
logger.info('Batch wise input only support controlnet, control-lora, and t2i adapters!')
|
||||
break
|
||||
|
||||
alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)]
|
||||
if preprocessor_output_is_image:
|
||||
alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)]
|
||||
params.control_cond = []
|
||||
params.control_cond_for_hr_fix = []
|
||||
|
||||
@@ -313,16 +333,26 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
params.control_cond_for_hr_fix = preprocessor_output
|
||||
p.extra_result_images.append(input_image)
|
||||
|
||||
if input_mask is not None:
|
||||
fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill
|
||||
params.control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border)
|
||||
p.extra_result_images.append(params.control_mask)
|
||||
params.control_mask = numpy_to_pytorch(params.control_mask).movedim(-1, 1)[:, :1]
|
||||
if len(control_masks) > 0:
|
||||
params.control_mask = []
|
||||
params.control_mask_for_hr_fix = []
|
||||
|
||||
for input_mask in control_masks:
|
||||
fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill
|
||||
control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border)
|
||||
p.extra_result_images.append(control_mask)
|
||||
control_mask = numpy_to_pytorch(control_mask).movedim(-1, 1)[:, :1]
|
||||
params.control_mask.append(control_mask)
|
||||
|
||||
if has_high_res_fix:
|
||||
control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border)
|
||||
p.extra_result_images.append(control_mask_for_hr_fix)
|
||||
control_mask_for_hr_fix = numpy_to_pytorch(control_mask_for_hr_fix).movedim(-1, 1)[:, :1]
|
||||
params.control_mask_for_hr_fix.append(control_mask_for_hr_fix)
|
||||
|
||||
params.control_mask = torch.cat(params.control_mask, dim=0)[alignment_indices].contiguous()
|
||||
if has_high_res_fix:
|
||||
params.control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border)
|
||||
p.extra_result_images.append(params.control_mask_for_hr_fix)
|
||||
params.control_mask_for_hr_fix = numpy_to_pytorch(params.control_mask_for_hr_fix).movedim(-1, 1)[:, :1]
|
||||
params.control_mask_for_hr_fix = torch.cat(params.control_mask_for_hr_fix, dim=0)[alignment_indices].contiguous()
|
||||
else:
|
||||
params.control_mask_for_hr_fix = params.control_mask
|
||||
|
||||
|
||||
@@ -87,6 +87,8 @@ def compute_controlnet_weighting(control, cnet):
|
||||
final_weight = final_weight * sigma_weight * frame_weight
|
||||
|
||||
if isinstance(advanced_mask_weighting, torch.Tensor):
|
||||
if control_signal.shape[0] == 2 * advanced_mask_weighting.shape[0]:
|
||||
advanced_mask_weighting = advanced_mask_weighting.repeat(2, 1, 1, 1)
|
||||
control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting.to(control_signal), size=(H, W), mode='bilinear')
|
||||
|
||||
control[k][i] = control_signal * final_weight[:, None, None, None]
|
||||
|
||||
Reference in New Issue
Block a user