Forge Space and BiRefNet

This commit is contained in:
lllyasviel
2024-08-17 08:29:08 -07:00
committed by GitHub
parent 8a04293430
commit 93b40f355e
13 changed files with 390 additions and 13 deletions

1
.gitignore vendored
View File

@@ -1,3 +1,4 @@
huggingface_space_mirror/
random_test.py
__pycache__
*.ckpt

View File

@@ -0,0 +1,67 @@
import spaces
import os
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
torch.set_float32_matmul_precision(["high", "highest"][0])
os.environ['HOME'] = spaces.convert_root_path() + 'home'
with spaces.GPUObject() as birefnet_gpu_obj:
birefnet = AutoModelForImageSegmentation.from_pretrained(
"ZhengPeng7/BiRefNet", trust_remote_code=True
)
transform_image = transforms.Compose(
[
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
@spaces.GPU(gpu_objects=[birefnet_gpu_obj])
def fn(image):
im = load_img(image, output_type="pil")
im = im.convert("RGB")
image_size = im.size
origin = im.copy()
image = load_img(im)
input_images = transform_image(image).unsqueeze(0).to(spaces.gpu)
# Prediction
with torch.no_grad():
preds = birefnet(input_images)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size)
image.putalpha(mask)
return (image, origin)
slider1 = ImageSlider(label="birefnet", type="pil")
slider2 = ImageSlider(label="birefnet", type="pil")
image = gr.Image(label="Upload an image")
text = gr.Textbox(label="Paste an image URL")
chameleon = load_img(spaces.convert_root_path() + "chameleon.jpg", output_type="pil")
url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
tab1 = gr.Interface(
fn, inputs=image, outputs=slider1, examples=[chameleon], api_name="image", allow_flagging="never"
)
tab2 = gr.Interface(fn, inputs=text, outputs=slider2, examples=[url], api_name="text", allow_flagging="never")
demo = gr.TabbedInterface(
[tab1, tab2], ["image", "text"], title="birefnet for background removal"
)
if __name__ == "__main__":
demo.launch(inbrowser=True)

View File

@@ -0,0 +1,5 @@
{
"tag": "Image Processing: Matting, Saliency, and Background Removal",
"title": "BiRefNet for Background Removal",
"repo_id": "not-lain/background-removal"
}

View File

@@ -5,6 +5,7 @@ import dataclasses
import os
import threading
import re
import json
from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo
@@ -124,6 +125,13 @@ class Extension:
self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
self.canonical_name = metadata.canonical_name
self.is_forge_space = False
self.space_meta = None
if os.path.exists(os.path.join(self.path, 'space_meta.json')) and os.path.exists(os.path.join(self.path, 'forge_app.py')):
self.is_forge_space = True
self.space_meta = json.load(open(os.path.join(self.path, 'space_meta.json'), 'rt', encoding='utf-8'))
def to_dict(self):
return {x: getattr(self, x) for x in self.cached_fields}

View File

@@ -1,4 +1,5 @@
import inspect
import types
import warnings
from functools import wraps
@@ -106,6 +107,25 @@ gradio_component_meta_create_or_modify_pyi_original = patches.patch(__file__, gr
# this function is broken and does not seem to do anything useful
gradio.component_meta.updateable = lambda x: x
class EventWrapper:
def __init__(self, replaced_event):
self.replaced_event = replaced_event
self.has_trigger = replaced_event.has_trigger
self.event_name = replaced_event.event_name
self.callback = replaced_event.callback
def __call__(self, *args, **kwargs):
if '_js' in kwargs:
kwargs['js'] = kwargs['_js']
del kwargs['_js']
return self.replaced_event(*args, **kwargs)
@property
def __self__(self):
return self.replaced_event.__self__
def repair(grclass):
if not getattr(grclass, 'EVENTS', None):
return
@@ -129,13 +149,7 @@ def repair(grclass):
for event in self.EVENTS:
replaced_event = getattr(self, str(event))
def fun(*xargs, _js=None, replaced_event=replaced_event, **xkwargs):
if _js:
xkwargs['js'] = _js
return replaced_event(*xargs, **xkwargs)
fun = EventWrapper(replaced_event)
setattr(self, str(event), fun)
grclass.__init__ = __repaired_init__

View File

@@ -26,7 +26,7 @@ import modules.shared as shared
from modules import prompt_parser
from modules.infotext_utils import image_from_url_text, PasteField
from modules_forge.forge_canvas.canvas import ForgeCanvas, canvas_head
from modules_forge import main_entry
from modules_forge import main_entry, forge_space
create_setting_component = ui_settings.create_setting_component
@@ -853,6 +853,9 @@ def create_ui():
extra_tabs.__exit__()
with gr.Blocks(analytics_enabled=False, head=canvas_head) as space_interface:
forge_space.main_entry()
scripts.scripts_current = None
with gr.Blocks(analytics_enabled=False, head=canvas_head) as extras_interface:
@@ -891,6 +894,7 @@ def create_ui():
interfaces = [
(txt2img_interface, "Txt2img", "txt2img"),
(img2img_interface, "Img2img", "img2img"),
(space_interface, "Spaces", "space"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"),

View File

@@ -0,0 +1,159 @@
import os
import sys
import uuid
import time
import gradio as gr
import importlib.util
import shutil
from gradio.context import Context
from threading import Thread
from huggingface_hub import snapshot_download
from backend import memory_management
spaces = []
def build_html(title, installed=False, url=None):
if not installed:
return f'<div>{title}</div><div style="color: grey;">Not Installed</div>'
if isinstance(url, str):
return f'<div>{title}</div><div>Currently Running: <a href="{url}" style="color: green;" target="_blank">{url}</a></div>'
else:
return f'<div>{title}</div><div style="color: grey;">Installed, Ready to Launch</div>'
class ForgeSpace:
def __init__(self, root_path, title, repo_id=None, repo_type='space', revision=None, **kwargs):
self.title = title
self.root_path = root_path
self.hf_path = os.path.join(root_path, 'huggingface_space_mirror')
self.repo_id = repo_id
self.repo_type = repo_type
self.revision = revision
self.is_running = False
self.gradio_metas = None
self.label = gr.HTML(build_html(title=title, url=None), elem_classes=['forge_space_label'])
self.btn_launch = gr.Button('Launch', elem_classes=['forge_space_btn'])
self.btn_terminate = gr.Button('Terminate', elem_classes=['forge_space_btn'])
self.btn_install = gr.Button('Install', elem_classes=['forge_space_btn'])
self.btn_uninstall = gr.Button('Uninstall', elem_classes=['forge_space_btn'])
comps = [
self.label,
self.btn_install,
self.btn_uninstall,
self.btn_launch,
self.btn_terminate
]
self.btn_launch.click(self.run, outputs=comps)
self.btn_terminate.click(self.terminate, outputs=comps)
self.btn_install.click(self.install, outputs=comps)
self.btn_uninstall.click(self.uninstall, outputs=comps)
Context.root_block.load(self.refresh_gradio, outputs=comps, queue=False, show_progress=False)
return
def refresh_gradio(self):
results = []
installed = os.path.exists(self.hf_path)
if isinstance(self.gradio_metas, tuple):
results.append(build_html(title=self.title, installed=installed, url=self.gradio_metas[1]))
else:
results.append(build_html(title=self.title, installed=installed, url=None))
results.append(gr.update(interactive=not installed))
results.append(gr.update(interactive=installed))
results.append(gr.update(interactive=installed and not self.is_running))
results.append(gr.update(interactive=installed and self.is_running))
return results
def install(self):
os.makedirs(self.hf_path, exist_ok=True)
if self.repo_id is None:
return self.refresh_gradio()
downloaded = snapshot_download(
repo_id=self.repo_id,
repo_type=self.repo_type,
revision=self.revision,
local_dir=self.hf_path,
force_download=True,
)
print(f'Downloaded: {downloaded}')
return self.refresh_gradio()
def uninstall(self):
shutil.rmtree(self.hf_path)
print(f'Deleted: {self.hf_path}')
return self.refresh_gradio()
def terminate(self):
self.is_running = False
while self.gradio_metas is not None:
time.sleep(0.1)
return self.refresh_gradio()
def run(self):
self.is_running = True
Thread(target=self.gradio_worker).start()
while self.gradio_metas is None:
time.sleep(0.1)
return self.refresh_gradio()
def gradio_worker(self):
memory_management.unload_all_models()
sys.path.insert(0, self.hf_path)
file_path = os.path.join(self.root_path, 'forge_app.py')
module_name = 'forge_space_' + str(uuid.uuid4()).replace('-', '_')
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
demo = getattr(module, 'demo')
self.gradio_metas = demo.launch(inbrowser=True, prevent_thread_lock=True)
while self.is_running:
time.sleep(0.1)
demo.close()
self.gradio_metas = None
if module_name in sys.modules:
del sys.modules[module_name]
return
def main_entry():
global spaces
from modules.extensions import extensions
tagged_extensions = {}
for ex in extensions:
if ex.enabled and ex.is_forge_space:
tag = ex.space_meta['tag']
if tag not in tagged_extensions:
tagged_extensions[tag] = []
tagged_extensions[tag].append(ex)
for tag, exs in tagged_extensions.items():
with gr.Accordion(tag, open=True):
for ex in exs:
with gr.Row(equal_height=True):
space = ForgeSpace(root_path=ex.path, **ex.space_meta)
spaces.append(space)
return

View File

@@ -2,6 +2,7 @@ import os
import sys
INITIALIZED = False
MONITOR_MODEL_MOVING = False
@@ -25,6 +26,13 @@ def monitor_module_moving():
def initialize_forge():
global INITIALIZED
if INITIALIZED:
return
INITIALIZED = True
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'packages_3rdparty'))
bad_list = ['--lowvram', '--medvram', '--medvram-sdxl']
@@ -60,9 +68,6 @@ def initialize_forge():
from modules_forge.bnb_installer import try_install_bnb
try_install_bnb()
import modules_forge.patch_basic
modules_forge.patch_basic.patch_all_basics()
from backend import stream
print('CUDA Using Stream:', stream.should_use_stream())
@@ -85,4 +90,8 @@ def initialize_forge():
if 'HF_HUB_CACHE' not in os.environ:
os.environ['HF_HUB_CACHE'] = diffusers_dir
import modules_forge.patch_basic
modules_forge.patch_basic.patch_all_basics()
return

View File

@@ -6,6 +6,8 @@ import warnings
import gradio.networking
import safetensors.torch
from tqdm import tqdm
def gradio_url_ok_fix(url: str) -> bool:
try:
@@ -55,7 +57,20 @@ def build_loaded(module, loader_name):
return
def always_show_tqdm(*args, **kwargs):
kwargs['disable'] = False
if 'name' in kwargs:
del kwargs['name']
return tqdm(*args, **kwargs)
def patch_all_basics():
import logging
from huggingface_hub import file_download
file_download.tqdm = always_show_tqdm
from transformers.dynamic_module_utils import logger
logger.setLevel(logging.ERROR)
gradio.networking.url_ok = gradio_url_ok_fix
build_loaded(safetensors.torch, 'load_file')
build_loaded(torch, 'load')

View File

@@ -35,4 +35,6 @@ httpx==0.24.1
pillow-avif-plugin==1.4.3
diffusers==0.29.2
gradio_rangeslider==0.0.6
gradio_imageslider==0.0.20
loadimg==0.1.2
tqdm==4.66.1

79
spaces.py Normal file
View File

@@ -0,0 +1,79 @@
from modules_forge.initialization import initialize_forge
initialize_forge()
import os
import torch
import inspect
from backend import memory_management
gpu = memory_management.get_torch_device()
class GPUObject:
def __init__(self):
self.module_list = []
def __enter__(self):
self.original_init = torch.nn.Module.__init__
self.original_to = torch.nn.Module.to
def patched_init(module, *args, **kwargs):
self.module_list.append(module)
return self.original_init(module, *args, **kwargs)
def patched_to(module, *args, **kwargs):
self.module_list.append(module)
return self.original_to(module, *args, **kwargs)
torch.nn.Module.__init__ = patched_init
torch.nn.Module.to = patched_to
return self
def __exit__(self, exc_type, exc_val, exc_tb):
torch.nn.Module.__init__ = self.original_init
torch.nn.Module.to = self.original_to
self.module_list = set(self.module_list)
self.to(device=torch.device('cpu'))
memory_management.soft_empty_cache()
return
def to(self, device):
for module in self.module_list:
module.to(device)
print(f'Forge Space: Moved {len(self.module_list)} Modules to {device}')
return self
def gpu(self):
self.to(device=gpu)
return self
def GPU(gpu_objects=None, manual_load=False):
gpu_objects = gpu_objects or []
def decorator(func):
def wrapper(*args, **kwargs):
print("Entering Forge Space GPU ...")
memory_management.unload_all_models()
if not manual_load:
for o in gpu_objects:
o.gpu()
result = func(*args, **kwargs)
print("Cleaning Forge Space GPU ...")
for o in gpu_objects:
o.to(device=torch.device('cpu'))
memory_management.soft_empty_cache()
return result
return wrapper
return decorator
def convert_root_path():
frame = inspect.currentframe().f_back
caller_file = frame.f_code.co_filename
caller_file = os.path.abspath(caller_file)
result = os.path.join(os.path.dirname(caller_file), 'huggingface_space_mirror')
return result + '/'

View File

@@ -1673,3 +1673,17 @@ body.resizing .resize-handle {
#quicksettings .gradio-slider span {
padding-right: 5px;
}
.forge_space_label{
padding: 10px;
min-width: 60% !important;
margin: 1px;
border-width: 1px;
border-radius: 8px;
border-style: solid;
border-color: #6f6f6f;
}
.forge_space_btn{
min-width: 0 !important;
}

View File

@@ -11,8 +11,6 @@ from modules_forge.initialization import initialize_forge
from modules_forge import main_thread
from modules_forge.forge_canvas.canvas import canvas_js_root_path
startup_timer = timer.startup_timer
startup_timer.record("launcher")
@@ -83,6 +81,8 @@ def webui_worker():
elif shared.opts.auto_launch_browser == "Local":
auto_launch_browser = not cmd_opts.webui_is_non_local
from modules_forge.forge_canvas.canvas import canvas_js_root_path
app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share,
server_name=initialize_util.gradio_server_name(),