diff --git a/extensions-builtin/forge_space_birefnet/forge_app.py b/extensions-builtin/forge_space_birefnet/forge_app.py index 96b686b2..ba2b98e9 100644 --- a/extensions-builtin/forge_space_birefnet/forge_app.py +++ b/extensions-builtin/forge_space_birefnet/forge_app.py @@ -2,7 +2,8 @@ import spaces import os import gradio as gr -from gradio_imageslider import ImageSlider +import gc + from loadimg import load_img from transformers import AutoModelForImageSegmentation import torch @@ -13,38 +14,54 @@ import pathlib from PIL import Image -with spaces.capture_gpu_object() as birefnet_gpu_obj: +transform_image = None +birefnet = None + +def load_model(model): + global birefnet + birefnet = None + gc.collect() + torch.cuda.empty_cache() + birefnet = AutoModelForImageSegmentation.from_pretrained( - "ZhengPeng7/BiRefNet", trust_remote_code=True + model, trust_remote_code=True ) + birefnet.eval() + birefnet.half() -spaces.automatically_move_to_gpu_when_forward(birefnet) + spaces.automatically_move_to_gpu_when_forward(birefnet) -transform_image = transforms.Compose( - [ - transforms.Resize((1024, 1024)), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ] -) +with spaces.capture_gpu_object() as birefnet_gpu_obj: + load_model("ZhengPeng7/BiRefNet_HR") + +def common_setup(size): + global transform_image + + transform_image = transforms.Compose( + [ + transforms.Resize((size, size)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) @spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True) -def fn(image): +def process(image): im = load_img(image, output_type="pil") im = im.convert("RGB") image_size = im.size - origin = im.copy() image = load_img(im) - input_images = transform_image(image).unsqueeze(0).to(spaces.gpu) + input_image = transform_image(image).unsqueeze(0).to(spaces.gpu).to(torch.float16) # Prediction with torch.no_grad(): - preds = 
birefnet(input_images)[-1].sigmoid().cpu() + preds = birefnet(input_image)[-1].sigmoid().cpu() pred = preds[0].squeeze() pred_pil = transforms.ToPILImage()(pred) mask = pred_pil.resize(image_size) image.putalpha(mask) - return (image, origin) + return image + @spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True) def batch_process(input_folder, output_folder, save_png, save_flat): @@ -52,7 +69,7 @@ def batch_process(input_folder, output_folder, save_png, save_flat): os.makedirs(output_folder, exist_ok=True) # Supported image extensions - image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.webp'] + image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', ".avif"] # Collect all image files from input folder input_images = [] @@ -70,8 +87,8 @@ def batch_process(input_folder, output_folder, save_png, save_flat): image = load_img(im) # Prepare image for processing - input_image = transform_image(image).unsqueeze(0).to(spaces.gpu) - + input_image = transform_image(image).unsqueeze(0).to(spaces.gpu).to(torch.float16) + # Prediction with torch.no_grad(): preds = birefnet(input_image)[-1].sigmoid().cpu() @@ -105,41 +122,68 @@ def batch_process(input_folder, output_folder, save_png, save_flat): return processed_images -slider1 = ImageSlider(label="birefnet", type="pil") -slider2 = ImageSlider(label="birefnet", type="pil") -image = gr.Image(label="Upload an image") -text = gr.Textbox(label="URL to image, or local path to image", max_lines=1) + +def unload(): + global birefnet, transform_image + birefnet = None + transform_image = None + gc.collect() + torch.cuda.empty_cache() -chameleon = load_img(spaces.convert_root_path() + "chameleon.jpg", output_type="pil") +css = """ +.gradio-container { + max-width: 1280px !important; +} +footer { + display: none !important; +} +""" -url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg" -tab1 = gr.Interface( - fn, inputs=image, outputs=slider1, examples=[chameleon], api_name="image", 
allow_flagging="never" -) +with gr.Blocks(css=css, analytics_enabled=False) as demo: + gr.Markdown("# birefnet for background removal") -tab2 = gr.Interface( - fn, inputs=text, outputs=slider2, examples=[url], api_name="text", allow_flagging="never" -) + with gr.Tab("image"): + with gr.Row(): + with gr.Column(): + image = gr.Image(label="Upload an image", type='pil', height=616) + go_image = gr.Button("Remove background") + with gr.Column(): + result1 = gr.Image(label="birefnet", type="pil", height=576) -tab3 = gr.Interface( - batch_process, - inputs=[ - gr.Textbox(label="Input folder path", max_lines=1), - gr.Textbox(label="Output folder path (will overwrite)", max_lines=1), - gr.Checkbox(label="Always save as PNG", value=True), - gr.Checkbox(label="Save flat (no mask)", value=False) - ], - outputs=gr.File(label="Processed images", type="filepath", file_count="multiple"), - api_name="batch", - allow_flagging="never" -) + with gr.Tab("URL"): + with gr.Row(): + with gr.Column(): + text = gr.Textbox(label="URL to image, or local path to image", max_lines=1) + go_text = gr.Button("Remove background") + with gr.Column(): + result2 = gr.Image(label="birefnet", type="pil", height=576) -demo = gr.TabbedInterface( - [tab1, tab2, tab3], - ["image", "URL", "batch"], - title="birefnet for background removal" -) + with gr.Tab("batch"): + with gr.Row(): + with gr.Column(): + input_dir = gr.Textbox(label="Input folder path", max_lines=1) + output_dir = gr.Textbox(label="Output folder path (will overwrite)", max_lines=1) + always_png = gr.Checkbox(label="Always save as PNG", value=True) + save_flat = gr.Checkbox(label="Save flat (no mask)", value=False) + go_batch = gr.Button("Remove background(s)") + with gr.Column(): + result3 = gr.File(label="Processed image(s)", type="filepath", file_count="multiple") + + with gr.Tab("options"): + model = gr.Dropdown(label="Model", + choices=["ZhengPeng7/BiRefNet", "ZhengPeng7/BiRefNet_HR"], value="ZhengPeng7/BiRefNet_HR", type="value") + 
proc_size = gr.Dropdown(label="birefnet processing image size", info="1024: old model; 2048: HR model - more accurate, uses more VRAM (shared memory works well)", + choices=[1024, 1536, 2048], value=2048) + + model.change(fn=load_model, inputs=model, outputs=None) + + + go_image.click(fn=common_setup, inputs=[proc_size]).then(fn=process, inputs=image, outputs=result1) + go_text.click(fn=common_setup, inputs=[proc_size]).then(fn=process, inputs=text, outputs=result2) + go_batch.click(fn=common_setup, inputs=[proc_size]).then(fn=batch_process, inputs=[input_dir, output_dir, always_png, save_flat], outputs=result3) + + demo.unload(unload) if __name__ == "__main__": demo.launch(inbrowser=True)