mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-22 07:43:58 +00:00
288 lines
10 KiB
Python
288 lines
10 KiB
Python
|
|
import spaces
|
|
import os
|
|
import gradio as gr
|
|
import gc
|
|
|
|
try:
|
|
import moviepy.editor as mp
|
|
got_mp = True
|
|
except:
|
|
got_mp = False
|
|
|
|
from loadimg import load_img
|
|
from transformers import AutoModelForImageSegmentation
|
|
import torch
|
|
from torchvision import transforms
|
|
|
|
import glob
|
|
import pathlib
|
|
from PIL import Image
|
|
import numpy
|
|
|
|
|
|
# Preprocessing pipeline (torchvision Compose); (re)built by common_setup()
# before every job so it reflects the current size sliders.
transform_image = None

# Currently-loaded BiRefNet model; (re)bound by load_model().
birefnet = None
|
|
|
|
def load_model(model):
    """Swap in a fresh BiRefNet checkpoint as the active model.

    model: checkpoint name under the "ZhengPeng7" Hugging Face namespace
        (e.g. "BiRefNet_HR"); downloaded on first use.

    Side effects: rebinds the module-level `birefnet` and registers it with
    spaces so it is moved to the GPU automatically on forward().
    """
    global birefnet

    # Drop the previous model first so its memory can be reclaimed
    # before the new checkpoint is materialised.
    birefnet = None
    gc.collect()
    torch.cuda.empty_cache()

    # Inference-only, half precision to halve the memory footprint
    # (nn.Module.eval()/half() both return self, so chaining is safe).
    birefnet = AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/" + model, trust_remote_code=True
    ).eval().half()

    spaces.automatically_move_to_gpu_when_forward(birefnet)
|
|
|
|
# Load the default model at import time while capturing its GPU allocations,
# so the @spaces.GPU-decorated handlers below can manage them.
with spaces.capture_gpu_object() as birefnet_gpu_obj:
    load_model("BiRefNet_HR")
|
|
|
|
def common_setup(w, h):
    """Rebuild the preprocessing pipeline for the requested processing size.

    w, h: processing width and height in pixels (from the options sliders).

    Side effects: rebinds the module-level `transform_image` used by the
    image/video/batch handlers.
    """
    global transform_image

    transform_image = transforms.Compose(
        [
            # Fix: torchvision Resize takes (height, width); the original
            # passed (w, h), which swapped the processing dimensions
            # whenever the two sliders were set to different values.
            transforms.Resize((h, w)),
            transforms.ToTensor(),
            # ImageNet mean/std, as expected by BiRefNet.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
|
|
|
|
|
|
@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True)
def process(image, save_flat, bg_colour):
    """Remove the background from a single image.

    image: anything load_img accepts (PIL image from gr.Image, or a
        URL / local path string from the URL tab).
    save_flat: composite onto an opaque bg_colour and return RGB instead
        of returning an RGBA image with the predicted alpha mask.
    bg_colour: "#RRGGBB" hex string used when save_flat is set.

    Returns the processed PIL image.
    """
    source = load_img(image, output_type="pil").convert("RGB")
    original_size = source.size
    result = load_img(source)

    # Run the model in half precision on the GPU; prediction only.
    batch = transform_image(result).unsqueeze(0).to(spaces.gpu).to(torch.float16)
    with torch.no_grad():
        prediction = birefnet(batch)[-1].sigmoid().cpu()

    # Turn the last-stage prediction into an alpha mask at the original size.
    alpha = transforms.ToPILImage()(prediction[0].squeeze()).resize(original_size)
    result.putalpha(alpha)

    if not save_flat:
        return result

    # Flatten: composite onto an opaque background and drop the alpha channel.
    hex_rgba = bg_colour + "FF"
    fill = tuple(int(hex_rgba[pos:pos+2], 16) for pos in (1, 3, 5, 7))
    backdrop = Image.new("RGBA", original_size, fill)
    return Image.alpha_composite(backdrop, result).convert("RGB")
|
|
|
|
# video processing based on https://huggingface.co/spaces/brokerrobin/video-background-removal/blob/main/app.py
@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True)
def video_process(video, bg_colour):
    """Remove the background from every frame of a video.

    video: path to the input video file (from gr.Video).
    bg_colour: "#RRGGBB" hex string composited behind each frame — video
        output has no alpha channel, so a flat background is required.

    Returns the path of the processed .mp4, written next to the input
    (which lives in gradio's temp area) with a "-birefnet" suffix.
    """
    # Load the video using moviepy
    clip = mp.VideoFileClip(video)
    try:
        fps = clip.fps

        # Keep the original audio so it can be re-attached to the output.
        audio = clip.audio

        processed_frames = []
        image_size = None
        background = None

        # Extract frames at the native FPS and process each one.
        for i, frame in enumerate(clip.iter_frames(fps=fps)):
            print (f"birefnet [video]: frame {i+1}", end='\r', flush=True)

            image = Image.fromarray(frame)

            # All frames share one size: build the flat background once.
            if i == 0:
                image_size = image.size
                colour_rgb = tuple(int(bg_colour[j:j+2], 16) for j in (1, 3, 5))
                background = Image.new("RGBA", image_size, colour_rgb + (255,))

            input_image = transform_image(image).unsqueeze(0).to(spaces.gpu).to(torch.float16)
            # Prediction
            with torch.no_grad():
                preds = birefnet(input_image)[-1].sigmoid().cpu()
            pred = preds[0].squeeze()
            pred_pil = transforms.ToPILImage()(pred)
            mask = pred_pil.resize(image_size)

            # Apply mask and composite onto the opaque background
            image.putalpha(mask)
            processed_frames.append(numpy.array(Image.alpha_composite(background, image)))

        # Create a new video from the processed frames and restore the audio
        processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
        processed_video = processed_video.set_audio(audio)

        # Save the processed video using modified original filename (goes to gradio temp)
        filename, _ = os.path.splitext(video)
        filename += "-birefnet.mp4"
        processed_video.write_videofile(filename, codec="libx264")
        processed_video.close()
    finally:
        # Fix: the original never closed the clip, leaking the moviepy
        # reader/ffmpeg handles for every processed video.
        clip.close()

    return filename
|
|
|
|
|
|
@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True)
def batch_process(input_folder, output_folder, save_png, save_flat, bg_colour):
    """Remove the background from every supported image in a folder.

    input_folder: directory scanned (non-recursively) for images.
    output_folder: created if missing; outputs keep their original
        filenames, so existing files are overwritten.
    save_png: force a .png extension (appended rather than substituted,
        to avoid overwriting a sibling file with the same stem).
    save_flat: composite onto an opaque bg_colour and save RGB instead of
        keeping the alpha mask.
    bg_colour: "#RRGGBB" hex string used when save_flat is set.

    Returns the list of written output paths; images that fail to process
    are skipped with a console message.
    """
    # Ensure output folder exists
    os.makedirs(output_folder, exist_ok=True)

    # Supported image extensions (compared lower-cased)
    image_extensions = {'.jpg', '.jpeg', '.jfif', '.png', '.bmp', '.webp', '.avif'}

    # Collect image files from the input folder. Fix: the original per-
    # extension glob was case-sensitive (missing .JPG/.PNG etc. on Linux)
    # and unordered; match case-insensitively and sort for a stable order.
    input_root = pathlib.Path(input_folder)
    if input_root.is_dir():
        input_images = sorted(
            str(p) for p in input_root.iterdir()
            if p.is_file() and p.suffix.lower() in image_extensions
        )
    else:
        input_images = []

    if save_flat:
        # Parse the opaque background colour once for the whole batch.
        bg_colour += "FF"
        colour_rgb = tuple(int(bg_colour[i:i+2], 16) for i in (1, 3, 5, 7))

    # Process each image
    processed_images = []
    for i, image_path in enumerate(input_images):
        print (f"birefnet [batch]: image {i+1}", end='\r', flush=True)
        try:
            # Load image
            im = load_img(image_path, output_type="pil")
            im = im.convert("RGB")
            image_size = im.size
            image = load_img(im)

            # Prepare image for processing (half precision on the GPU)
            input_image = transform_image(image).unsqueeze(0).to(spaces.gpu).to(torch.float16)

            # Prediction
            with torch.no_grad():
                preds = birefnet(input_image)[-1].sigmoid().cpu()

            pred = preds[0].squeeze()
            pred_pil = transforms.ToPILImage()(pred)
            mask = pred_pil.resize(image_size)

            # Apply mask
            image.putalpha(mask)

            # Save processed image under the original filename
            output_filename = os.path.join(output_folder, f"{pathlib.Path(image_path).name}")

            if save_flat:
                background = Image.new("RGBA", image_size, colour_rgb)
                image = Image.alpha_composite(background, image)
                image = image.convert("RGB")
            elif output_filename.lower().endswith(".jpg") or output_filename.lower().endswith(".jpeg"):
                # jpegs don't support alpha channel, so add .png extension (not change, to avoid potential overwrites)
                output_filename += ".png"
            if save_png and not output_filename.lower().endswith(".png"):
                output_filename += ".png"

            image.save(output_filename)

            processed_images.append(output_filename)

        except Exception as e:
            print(f"Error processing {image_path}: {str(e)}")

    return processed_images
|
|
|
|
|
|
def unload():
    """Release the model and transform when the gradio session ends.

    Rebinds both module-level objects to None, then forces a garbage
    collection pass and frees any cached CUDA allocations.
    """
    global birefnet, transform_image

    transform_image = None
    birefnet = None

    gc.collect()
    torch.cuda.empty_cache()
|
|
|
|
|
|
css = """
|
|
.gradio-container {
|
|
max-width: 1280px !important;
|
|
}
|
|
footer {
|
|
display: none !important;
|
|
}
|
|
"""
|
|
|
|
with gr.Blocks(css=css, analytics_enabled=False) as demo:
    gr.Markdown("# birefnet for background removal")

    # --- single uploaded image ---
    with gr.Tab("image"):
        with gr.Row():
            with gr.Column():
                image = gr.Image(label="Upload an image", type='pil', height=584)
                go_image = gr.Button("Remove background")
            with gr.Column():
                result1 = gr.Image(label="birefnet", type="pil", height=544)

    # --- image fetched from a URL or local path ---
    with gr.Tab("URL"):
        with gr.Row():
            with gr.Column():
                text = gr.Textbox(label="URL to image, or local path to image", max_lines=1)
                go_text = gr.Button("Remove background")
            with gr.Column():
                result2 = gr.Image(label="birefnet", type="pil", height=544)

    # --- video tab, only when moviepy imported successfully ---
    if got_mp:
        with gr.Tab("video"):
            with gr.Row():
                with gr.Column():
                    video = gr.Video(label="Upload a video", height=584)
                    go_video = gr.Button("Remove background")
                with gr.Column():
                    result4 = gr.Video(label="birefnet", height=544, show_share_button=False)

    # --- whole-folder batch processing ---
    with gr.Tab("batch"):
        with gr.Row():
            with gr.Column():
                input_dir = gr.Textbox(label="Input folder path", max_lines=1)
                output_dir = gr.Textbox(label="Output folder path (save images will overwrite)", max_lines=1)
                always_png = gr.Checkbox(label="Always save as PNG", value=True)

                go_batch = gr.Button("Remove background(s)")
            with gr.Column():
                result3 = gr.File(label="Processed image(s)", type="filepath", file_count="multiple")

    # --- model choice and processing options ---
    with gr.Tab("options"):
        gr.Markdown("*HR* : high resolution; *matting* : better with transparency; *lite* : faster.")
        model = gr.Dropdown(label="Model (download on selection, see console for progress)",
            choices=["BiRefNet_512x512", "BiRefNet", "BiRefNet_HR", "BiRefNet-matting", "BiRefNet_HR-matting", "BiRefNet_lite", "BiRefNet_lite-2K", "BiRefNet-portrait", "BiRefNet-COD", "BiRefNet-DIS5K", "BiRefNet-DIS5k-TR_TEs", "BiRefNet-HRSOD"], value="BiRefNet_HR", type="value")

        gr.Markdown("Regular models trained at 1024 \u00D7 1024; HR models trained at 2048 \u00D7 2048; 2K model trained at 2560 \u00D7 1440.")
        gr.Markdown("Greater processing image size will typically give more accurate results, but also requires more VRAM (shared memory works well).")
        with gr.Row():
            proc_sizeW = gr.Slider(label="birefnet processing image width",
                minimum=256, maximum=2560, value=2048, step=32)
            proc_sizeH = gr.Slider(label="birefnet processing image height",
                minimum=256, maximum=2048, value=2048, step=32)
        with gr.Row():
            save_flat = gr.Checkbox(label="Save flat (no mask)", value=False)
            bg_colour = gr.ColorPicker(label="Background colour for saving flat, and video", value="#00FF00", visible=True, interactive=True)

    # Switching models reloads immediately (download happens on selection).
    model.change(fn=load_model, inputs=model, outputs=None)

    gr.Markdown("### https://github.com/ZhengPeng7/BiRefNet\n### https://huggingface.co/ZhengPeng7")

    # Wiring: every run first rebuilds the transform at the current slider
    # size (common_setup), then executes the matching handler.
    go_image.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(fn=process, inputs=[image, save_flat, bg_colour], outputs=result1)
    go_text.click( fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(fn=process, inputs=[text, save_flat, bg_colour], outputs=result2)
    if got_mp:
        go_video.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(
            fn=video_process, inputs=[video, bg_colour], outputs=result4)
    go_batch.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(
        fn=batch_process, inputs=[input_dir, output_dir, always_png, save_flat, bg_colour], outputs=result3)

    # Free the model when the browser session goes away.
    demo.unload(unload)
|
|
|
|
# Script entry point: launch the gradio app and open it in the browser.
if __name__ == "__main__":
    demo.launch(inbrowser=True)
|