From a9e247eb2e851e1b5018124774559c1c5cb2499d Mon Sep 17 00:00:00 2001 From: DenOfEquity <166248528+DenOfEquity@users.noreply.github.com> Date: Wed, 12 Feb 2025 18:50:49 +0000 Subject: [PATCH] birefnet Space: more models, video (#2648) --- .../forge_space_birefnet/forge_app.py | 146 +++++++++++++++--- 1 file changed, 122 insertions(+), 24 deletions(-) diff --git a/extensions-builtin/forge_space_birefnet/forge_app.py b/extensions-builtin/forge_space_birefnet/forge_app.py index ba2b98e9..46060bef 100644 --- a/extensions-builtin/forge_space_birefnet/forge_app.py +++ b/extensions-builtin/forge_space_birefnet/forge_app.py @@ -4,6 +4,12 @@ import os import gradio as gr import gc +try: + import moviepy.editor as mp + got_mp = True +except Exception: + got_mp = False + from loadimg import load_img from transformers import AutoModelForImageSegmentation import torch @@ -12,6 +18,7 @@ from torchvision import transforms import glob import pathlib from PIL import Image +import numpy transform_image = None @@ -24,7 +31,7 @@ def load_model(model): torch.cuda.empty_cache() birefnet = AutoModelForImageSegmentation.from_pretrained( - model, trust_remote_code=True + "ZhengPeng7/"+model, trust_remote_code=True ) birefnet.eval() birefnet.half() @@ -32,14 +39,14 @@ def load_model(model): spaces.automatically_move_to_gpu_when_forward(birefnet) with spaces.capture_gpu_object() as birefnet_gpu_obj: - load_model("ZhengPeng7/BiRefNet_HR") + load_model("BiRefNet_HR") -def common_setup(size): +def common_setup(w, h): global transform_image transform_image = transforms.Compose( [ - transforms.Resize((size, size)), + transforms.Resize((h, w)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ] @@ -47,7 +54,7 @@ def common_setup(size): @spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True) -def process(image): +def process(image, save_flat, bg_colour): im = load_img(image, output_type="pil") im = im.convert("RGB") image_size = im.size @@ -60,25 +67,92 @@ def 
process(image): pred_pil = transforms.ToPILImage()(pred) mask = pred_pil.resize(image_size) image.putalpha(mask) + + if save_flat: + bg_colour += "FF" + colour_rgb = tuple(int(bg_colour[i:i+2], 16) for i in (1, 3, 5, 7)) + background = Image.new("RGBA", image_size, colour_rgb) + image = Image.alpha_composite(background, image) + image = image.convert("RGB") + return image +# video processing based on https://huggingface.co/spaces/brokerrobin/video-background-removal/blob/main/app.py +@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True) +def video_process(video, bg_colour): + # Load the video using moviepy + video = mp.VideoFileClip(video) + + fps = video.fps + + # Extract audio from the video + audio = video.audio + + # Extract frames at the specified FPS + frames = video.iter_frames(fps=fps) + + # Process each frame for background removal + processed_frames = [] + + for i, frame in enumerate(frames): + print (f"birefnet [video]: frame {i+1}", end='\r', flush=True) + + image = Image.fromarray(frame) + + if i == 0: + image_size = image.size + + colour_rgb = tuple(int(bg_colour[i:i+2], 16) for i in (1, 3, 5)) + background = Image.new("RGBA", image_size, colour_rgb + (255,)) + + input_image = transform_image(image).unsqueeze(0).to(spaces.gpu).to(torch.float16) + # Prediction + with torch.no_grad(): + preds = birefnet(input_image)[-1].sigmoid().cpu() + pred = preds[0].squeeze() + pred_pil = transforms.ToPILImage()(pred) + mask = pred_pil.resize(image_size) + + # Apply mask and composite + image.putalpha(mask) + processed_image = Image.alpha_composite(background, image) + + processed_frames.append(numpy.array(processed_image)) + + # Create a new video from the processed frames + processed_video = mp.ImageSequenceClip(processed_frames, fps=fps) + + # Add the original audio back to the processed video + processed_video = processed_video.set_audio(audio) + + # Save the processed video using modified original filename (goes to gradio temp) + filename, _ = 
os.path.splitext(video.filename) + filename += "-birefnet.mp4" + processed_video.write_videofile(filename, codec="libx264") + + return filename + @spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True) -def batch_process(input_folder, output_folder, save_png, save_flat): +def batch_process(input_folder, output_folder, save_png, save_flat, bg_colour): # Ensure output folder exists os.makedirs(output_folder, exist_ok=True) # Supported image extensions - image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', ".avif"] + image_extensions = ['.jpg', '.jpeg', '.jfif', '.png', '.bmp', '.webp', ".avif"] # Collect all image files from input folder input_images = [] for ext in image_extensions: input_images.extend(glob.glob(os.path.join(input_folder, f'*{ext}'))) - + + if save_flat: + bg_colour += "FF" + colour_rgb = tuple(int(bg_colour[i:i+2], 16) for i in (1, 3, 5, 7)) # Process each image processed_images = [] - for image_path in input_images: + for i, image_path in enumerate(input_images): + print (f"birefnet [batch]: image {i+1}", end='\r', flush=True) try: # Load image im = load_img(image_path, output_type="pil") @@ -104,7 +178,7 @@ def batch_process(input_folder, output_folder, save_png, save_flat): output_filename = os.path.join(output_folder, f"{pathlib.Path(image_path).name}") if save_flat: - background = Image.new('RGBA', image.size, (255, 255, 255)) + background = Image.new("RGBA", image_size, colour_rgb) image = Image.alpha_composite(background, image) image = image.convert("RGB") elif output_filename.lower().endswith(".jpg") or output_filename.lower().endswith(".jpeg"): @@ -146,10 +220,10 @@ with gr.Blocks(css=css, analytics_enabled=False) as demo: with gr.Tab("image"): with gr.Row(): with gr.Column(): - image = gr.Image(label="Upload an image", type='pil', height=616) + image = gr.Image(label="Upload an image", type='pil', height=584) go_image = gr.Button("Remove background") with gr.Column(): - result1 = gr.Image(label="birefnet", type="pil", 
height=576) + result1 = gr.Image(label="birefnet", type="pil", height=544) with gr.Tab("URL"): with gr.Row(): @@ -157,31 +231,55 @@ with gr.Blocks(css=css, analytics_enabled=False) as demo: text = gr.Textbox(label="URL to image, or local path to image", max_lines=1) go_text = gr.Button("Remove background") with gr.Column(): - result2 = gr.Image(label="birefnet", type="pil", height=576) + result2 = gr.Image(label="birefnet", type="pil", height=544) + + if got_mp: + with gr.Tab("video"): + with gr.Row(): + with gr.Column(): + video = gr.Video(label="Upload a video", height=584) + go_video = gr.Button("Remove background") + with gr.Column(): + result4 = gr.Video(label="birefnet", height=544, show_share_button=False) with gr.Tab("batch"): with gr.Row(): with gr.Column(): input_dir = gr.Textbox(label="Input folder path", max_lines=1) - output_dir = gr.Textbox(label="Output folder path (will overwrite)", max_lines=1) + output_dir = gr.Textbox(label="Output folder path (save images will overwrite)", max_lines=1) always_png = gr.Checkbox(label="Always save as PNG", value=True) - save_flat = gr.Checkbox(label="Save flat (no mask)", value=False) + go_batch = gr.Button("Remove background(s)") with gr.Column(): result3 = gr.File(label="Processed image(s)", type="filepath", file_count="multiple") with gr.Tab("options"): - model = gr.Dropdown(label="Model", - choices=["ZhengPeng7/BiRefNet", "ZhengPeng7/BiRefNet_HR"], value="ZhengPeng7/BiRefNet_HR", type="value") - proc_size = gr.Dropdown(label="birefnet processing image size", info="1024: old model; 2048: HR model - more accurate, uses more VRAM (shared memory works well)", - choices=[1024, 1536, 2048], value=2048) - + gr.Markdown("*HR* : high resolution; *matting* : better with transparency; *lite* : faster.") + model = gr.Dropdown(label="Model (download on selection, see console for progress)", + choices=["BiRefNet_512x512", "BiRefNet", "BiRefNet_HR", "BiRefNet-matting", "BiRefNet_HR-matting", "BiRefNet_lite", 
"BiRefNet_lite-2K", "BiRefNet-portrait", "BiRefNet-COD", "BiRefNet-DIS5K", "BiRefNet-DIS5K-TR_TEs", "BiRefNet-HRSOD"], value="BiRefNet_HR", type="value") + + gr.Markdown("Regular models trained at 1024 \u00D7 1024; HR models trained at 2048 \u00D7 2048; 2K model trained at 2560 \u00D7 1440.") + gr.Markdown("Greater processing image size will typically give more accurate results, but also requires more VRAM (shared memory works well).") + with gr.Row(): + proc_sizeW = gr.Slider(label="birefnet processing image width", + minimum=256, maximum=2560, value=2048, step=32) + proc_sizeH = gr.Slider(label="birefnet processing image height", + minimum=256, maximum=2048, value=2048, step=32) + with gr.Row(): + save_flat = gr.Checkbox(label="Save flat (no mask)", value=False) + bg_colour = gr.ColorPicker(label="Background colour for saving flat, and video", value="#00FF00", visible=True, interactive=True) + model.change(fn=load_model, inputs=model, outputs=None) - - go_image.click(fn=common_setup, inputs=[proc_size]).then(fn=process, inputs=image, outputs=result1) - go_text.click( fn=common_setup, inputs=[proc_size]).then(fn=process, inputs=text, outputs=result2) - go_batch.click(fn=common_setup, inputs=[proc_size]).then(fn=batch_process, inputs=[input_dir, output_dir, always_png, save_flat], outputs=result3) + gr.Markdown("### https://github.com/ZhengPeng7/BiRefNet\n### https://huggingface.co/ZhengPeng7") + + go_image.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(fn=process, inputs=[image, save_flat, bg_colour], outputs=result1) + go_text.click( fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then(fn=process, inputs=[text, save_flat, bg_colour], outputs=result2) + if got_mp: + go_video.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then( + fn=video_process, inputs=[video, bg_colour], outputs=result4) + go_batch.click(fn=common_setup, inputs=[proc_sizeW, proc_sizeH]).then( + fn=batch_process, inputs=[input_dir, output_dir, always_png, save_flat, 
bg_colour], outputs=result3) demo.unload(unload)