mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-13 00:49:48 +00:00
birefnet Space: more models, video (#2648)
This commit is contained in:
@@ -4,6 +4,12 @@ import os
|
||||
import gradio as gr
|
||||
import gc
|
||||
|
||||
try:
|
||||
import moviepy.editor as mp
|
||||
got_mp = True
|
||||
except:
|
||||
got_mp = False
|
||||
|
||||
from loadimg import load_img
|
||||
from transformers import AutoModelForImageSegmentation
|
||||
import torch
|
||||
@@ -12,6 +18,7 @@ from torchvision import transforms
|
||||
import glob
|
||||
import pathlib
|
||||
from PIL import Image
|
||||
import numpy
|
||||
|
||||
|
||||
transform_image = None
|
||||
@@ -24,7 +31,7 @@ def load_model(model):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
birefnet = AutoModelForImageSegmentation.from_pretrained(
|
||||
model, trust_remote_code=True
|
||||
"ZhengPeng7/"+model, trust_remote_code=True
|
||||
)
|
||||
birefnet.eval()
|
||||
birefnet.half()
|
||||
@@ -32,14 +39,14 @@ def load_model(model):
|
||||
spaces.automatically_move_to_gpu_when_forward(birefnet)
|
||||
|
||||
with spaces.capture_gpu_object() as birefnet_gpu_obj:
|
||||
load_model("ZhengPeng7/BiRefNet_HR")
|
||||
load_model("BiRefNet_HR")
|
||||
|
||||
def common_setup(size):
|
||||
def common_setup(w, h):
|
||||
global transform_image
|
||||
|
||||
transform_image = transforms.Compose(
|
||||
[
|
||||
transforms.Resize((size, size)),
|
||||
transforms.Resize((w, h)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
]
|
||||
@@ -47,7 +54,7 @@ def common_setup(size):
|
||||
|
||||
|
||||
@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True)
|
||||
def process(image):
|
||||
def process(image, save_flat, bg_colour):
|
||||
im = load_img(image, output_type="pil")
|
||||
im = im.convert("RGB")
|
||||
image_size = im.size
|
||||
@@ -60,25 +67,92 @@ def process(image):
|
||||
pred_pil = transforms.ToPILImage()(pred)
|
||||
mask = pred_pil.resize(image_size)
|
||||
image.putalpha(mask)
|
||||
|
||||
if save_flat:
|
||||
bg_colour += "FF"
|
||||
colour_rgb = tuple(int(bg_colour[i:i+2], 16) for i in (1, 3, 5, 7))
|
||||
background = Image.new("RGBA", image_size, colour_rgb)
|
||||
image = Image.alpha_composite(background, image)
|
||||
image = image.convert("RGB")
|
||||
|
||||
return image
|
||||
|
||||
# video processing based on https://huggingface.co/spaces/brokerrobin/video-background-removal/blob/main/app.py
|
||||
@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True)
|
||||
def video_process(video, bg_colour):
|
||||
# Load the video using moviepy
|
||||
video = mp.VideoFileClip(video)
|
||||
|
||||
fps = video.fps
|
||||
|
||||
# Extract audio from the video
|
||||
audio = video.audio
|
||||
|
||||
# Extract frames at the specified FPS
|
||||
frames = video.iter_frames(fps=fps)
|
||||
|
||||
# Process each frame for background removal
|
||||
processed_frames = []
|
||||
|
||||
for i, frame in enumerate(frames):
|
||||
print (f"birefnet [video]: frame {i+1}", end='\r', flush=True)
|
||||
|
||||
image = Image.fromarray(frame)
|
||||
|
||||
if i == 0:
|
||||
image_size = image.size
|
||||
|
||||
colour_rgb = tuple(int(bg_colour[i:i+2], 16) for i in (1, 3, 5))
|
||||
background = Image.new("RGBA", image_size, colour_rgb + (255,))
|
||||
|
||||
input_image = transform_image(image).unsqueeze(0).to(spaces.gpu).to(torch.float16)
|
||||
# Prediction
|
||||
with torch.no_grad():
|
||||
preds = birefnet(input_image)[-1].sigmoid().cpu()
|
||||
pred = preds[0].squeeze()
|
||||
pred_pil = transforms.ToPILImage()(pred)
|
||||
mask = pred_pil.resize(image_size)
|
||||
|
||||
# Apply mask and composite
|
||||
image.putalpha(mask)
|
||||
processed_image = Image.alpha_composite(background, image)
|
||||
|
||||
processed_frames.append(numpy.array(processed_image))
|
||||
|
||||
# Create a new video from the processed frames
|
||||
processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
|
||||
|
||||
# Add the original audio back to the processed video
|
||||
processed_video = processed_video.set_audio(audio)
|
||||
|
||||
# Save the processed video using modified original filename (goes to gradio temp)
|
||||
filename, _ = os.path.splitext(video.filename)
|
||||
filename += "-birefnet.mp4"
|
||||
processed_video.write_videofile(filename, codec="libx264")
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True)
|
||||
def batch_process(input_folder, output_folder, save_png, save_flat):
|
||||
def batch_process(input_folder, output_folder, save_png, save_flat, bg_colour):
|
||||
# Ensure output folder exists
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
|
||||
# Supported image extensions
|
||||
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', ".avif"]
|
||||
image_extensions = ['.jpg', '.jpeg', '.jfif', '.png', '.bmp', '.webp', ".avif"]
|
||||
|
||||
# Collect all image files from input folder
|
||||
input_images = []
|
||||
for ext in image_extensions:
|
||||
input_images.extend(glob.glob(os.path.join(input_folder, f'*{ext}')))
|
||||
|
||||
|
||||
if save_flat:
|
||||
bg_colour += "FF"
|
||||
colour_rgb = tuple(int(bg_colour[i:i+2], 16) for i in (1, 3, 5, 7))
|
||||
# Process each image
|
||||
processed_images = []
|
||||
for image_path in input_images:
|
||||
for i, image_path in enumerate(input_images):
|
||||
print (f"birefnet [batch]: image {i+1}", end='\r', flush=True)
|
||||
try:
|
||||
# Load image
|
||||
im = load_img(image_path, output_type="pil")
|
||||
@@ -104,7 +178,7 @@ def batch_process(input_folder, output_folder, save_png, save_flat):
|
||||
output_filename = os.path.join(output_folder, f"{pathlib.Path(image_path).name}")
|
||||
|
||||
if save_flat:
|
||||
background = Image.new('RGBA', image.size, (255, 255, 255))
|
||||
background = Image.new("RGBA", image_size, colour_rgb)
|
||||
image = Image.alpha_composite(background, image)
|
||||
image = image.convert("RGB")
|
||||
elif output_filename.lower().endswith(".jpg") or output_filename.lower().endswith(".jpeg"):
|
||||
@@ -146,10 +220,10 @@ with gr.Blocks(css=css, analytics_enabled=False) as demo:
|
||||
with gr.Tab("image"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
image = gr.Image(label="Upload an image", type='pil', height=616)
|
||||
image = gr.Image(label="Upload an image", type='pil', height=584)
|
||||
go_image = gr.Button("Remove background")
|
||||
with gr.Column():
|
||||
result1 = gr.Image(label="birefnet", type="pil", height=576)
|
||||
result1 = gr.Image(label="birefnet", type="pil", height=544)
|
||||
|
||||
with gr.Tab("URL"):
|
||||
with gr.Row():
|
||||
@@ -157,31 +231,55 @@ with gr.Blocks(css=css, analytics_enabled=False) as demo:
|
||||
text = gr.Textbox(label="URL to image, or local path to image", max_lines=1)
|
||||
go_text = gr.Button("Remove background")
|
||||
with gr.Column():
|
||||
result2 = gr.Image(label="birefnet", type="pil", height=576)
|
||||
result2 = gr.Image(label="birefnet", type="pil", height=544)
|
||||
|
||||
if got_mp:
|
||||
with gr.Tab("video"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
video = gr.Video(label="Upload a video", height=584)
|
||||
go_video = gr.Button("Remove background")
|
||||
with gr.Column():
|
||||
result4 = gr.Video(label="birefnet", height=544, show_share_button=False)
|
||||
|
||||
with gr.Tab("batch"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
input_dir = gr.Textbox(label="Input folder path", max_lines=1)
|
||||
output_dir = gr.Textbox(label="Output folder path (will overwrite)", max_lines=1)
|
||||
output_dir = gr.Textbox(label="Output folder path (save images will overwrite)", max_lines=1)
|
||||
always_png = gr.Checkbox(label="Always save as PNG", value=True)
|
||||
save_flat = gr.Checkbox(label="Save flat (no mask)", value=False)
|
||||
|
||||
go_batch = gr.Button("Remove background(s)")
|
||||
with gr.Column():
|
||||
result3 = gr.File(label="Processed image(s)", type="filepath", file_count="multiple")
|
||||
|
||||
with gr.Tab("options"):
|
||||
model = gr.Dropdown(label="Model",
|
||||
choices=["ZhengPeng7/BiRefNet", "ZhengPeng7/BiRefNet_HR"], value="ZhengPeng7/BiRefNet_HR", type="value")
|
||||
proc_size = gr.Dropdown(label="birefnet processing image size", info="1024: old model; 2048: HR model - more accurate, uses more VRAM (shared memory works well)",
|
||||
choices=[1024, 1536, 2048], value=2048)
|
||||
|
||||
gr.Markdown("*HR* : high resolution; *matting* : better with transparency; *lite* : faster.")
|
||||
model = gr.Dropdown(label="Model (download on selection, see console for progress)",
|
||||
choices=["BiRefNet_512x512", "BiRefNet", "BiRefNet_HR", "BiRefNet-matting", "BiRefNet_HR-matting", "BiRefNet_lite", "BiRefNet_lite-2K", "BiRefNet-portrait", "BiRefNet-COD", "BiRefNet-DIS5K", "BiRefNet-DIS5k-TR_TEs", "BiRefNet-HRSOD"], value="BiRefNet_HR", type="value")
|
||||
|
||||
gr.Markdown("Regular models trained at 1024 \u00D7 1024; HR models trained at 2048 \u00D7 2048; 2K model trained at 2560 \u00D7 1440.")
|
||||
gr.Markdown("Greater processing image size will typically give more accurate results, but also requires more VRAM (shared memory works well).")
|
||||
with gr.Row():
|
||||
proc_sizeW = gr.Slider(label="birefnet processing image width",
|
||||
minimum=256, maximum=2560, value=2048, step=32)
|
||||
proc_sizeH = gr.Slider(label="birefnet processing image height",
|
||||
minimum=256, maximum=2048, value=2048, step=32)
|
||||
with gr.Row():
|
||||
save_flat = gr.Checkbox(label="Save flat (no mask)", value=False)
|
||||
bg_colour = gr.ColorPicker(label="Background colour for saving flat, and video", value="#00FF00", visible=True, interactive=True)
|
||||
|
||||
model.change(fn=load_model, inputs=model, outputs=None)
|
||||
|
||||
|
||||
go_image.click(fn=common_setup, inputs=[proc_size]).then(fn=process, inputs=image, outputs=result1)
|
||||
go_text.click( fn=common_setup, inputs=[proc_size]).then(fn=process, inputs=text, outputs=result2)
|
||||
go_batch.click(fn=common_setup, inputs=[proc_size]).then(fn=batch_process, inputs=[input_dir, output_dir, always_png, save_flat], outputs=result3)
|
||||
gr.Markdown("### https://github.com/ZhengPeng7/BiRefNet\n### https://huggingface.co/ZhengPeng7")
|
||||
|
||||
go_image.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(fn=process, inputs=[image, save_flat, bg_colour], outputs=result1)
|
||||
go_text.click( fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(fn=process, inputs=[text, save_flat, bg_colour], outputs=result2)
|
||||
if got_mp:
|
||||
go_video.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(
|
||||
fn=video_process, inputs=[video, bg_colour], outputs=result4)
|
||||
go_batch.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(
|
||||
fn=batch_process, inputs=[input_dir, output_dir, always_png, save_flat, bg_colour], outputs=result3)
|
||||
|
||||
demo.unload(unload)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user