Forge Space and BiRefNet

This commit is contained in:
lllyasviel
2024-08-17 08:29:08 -07:00
committed by GitHub
parent 8a04293430
commit 93b40f355e
13 changed files with 390 additions and 13 deletions

1
.gitignore vendored
View File

@@ -1,3 +1,4 @@
huggingface_space_mirror/
random_test.py random_test.py
__pycache__ __pycache__
*.ckpt *.ckpt

View File

@@ -0,0 +1,67 @@
import spaces
import os
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
torch.set_float32_matmul_precision(["high", "highest"][0])
os.environ['HOME'] = spaces.convert_root_path() + 'home'
with spaces.GPUObject() as birefnet_gpu_obj:
birefnet = AutoModelForImageSegmentation.from_pretrained(
"ZhengPeng7/BiRefNet", trust_remote_code=True
)
transform_image = transforms.Compose(
[
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
@spaces.GPU(gpu_objects=[birefnet_gpu_obj])
def fn(image):
im = load_img(image, output_type="pil")
im = im.convert("RGB")
image_size = im.size
origin = im.copy()
image = load_img(im)
input_images = transform_image(image).unsqueeze(0).to(spaces.gpu)
# Prediction
with torch.no_grad():
preds = birefnet(input_images)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size)
image.putalpha(mask)
return (image, origin)
slider1 = ImageSlider(label="birefnet", type="pil")
slider2 = ImageSlider(label="birefnet", type="pil")
image = gr.Image(label="Upload an image")
text = gr.Textbox(label="Paste an image URL")
chameleon = load_img(spaces.convert_root_path() + "chameleon.jpg", output_type="pil")
url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
tab1 = gr.Interface(
fn, inputs=image, outputs=slider1, examples=[chameleon], api_name="image", allow_flagging="never"
)
tab2 = gr.Interface(fn, inputs=text, outputs=slider2, examples=[url], api_name="text", allow_flagging="never")
demo = gr.TabbedInterface(
[tab1, tab2], ["image", "text"], title="birefnet for background removal"
)
if __name__ == "__main__":
demo.launch(inbrowser=True)

View File

@@ -0,0 +1,5 @@
{
"tag": "Image Processing: Matting, Saliency, and Background Removal",
"title": "BiRefNet for Background Removal",
"repo_id": "not-lain/background-removal"
}

View File

@@ -5,6 +5,7 @@ import dataclasses
import os import os
import threading import threading
import re import re
import json
from modules import shared, errors, cache, scripts from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo from modules.gitpython_hack import Repo
@@ -124,6 +125,13 @@ class Extension:
self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower()) self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
self.canonical_name = metadata.canonical_name self.canonical_name = metadata.canonical_name
self.is_forge_space = False
self.space_meta = None
if os.path.exists(os.path.join(self.path, 'space_meta.json')) and os.path.exists(os.path.join(self.path, 'forge_app.py')):
self.is_forge_space = True
self.space_meta = json.load(open(os.path.join(self.path, 'space_meta.json'), 'rt', encoding='utf-8'))
def to_dict(self): def to_dict(self):
return {x: getattr(self, x) for x in self.cached_fields} return {x: getattr(self, x) for x in self.cached_fields}

View File

@@ -1,4 +1,5 @@
import inspect import inspect
import types
import warnings import warnings
from functools import wraps from functools import wraps
@@ -106,6 +107,25 @@ gradio_component_meta_create_or_modify_pyi_original = patches.patch(__file__, gr
# this function is broken and does not seem to do anything useful # this function is broken and does not seem to do anything useful
gradio.component_meta.updateable = lambda x: x gradio.component_meta.updateable = lambda x: x
class EventWrapper:
def __init__(self, replaced_event):
self.replaced_event = replaced_event
self.has_trigger = replaced_event.has_trigger
self.event_name = replaced_event.event_name
self.callback = replaced_event.callback
def __call__(self, *args, **kwargs):
if '_js' in kwargs:
kwargs['js'] = kwargs['_js']
del kwargs['_js']
return self.replaced_event(*args, **kwargs)
@property
def __self__(self):
return self.replaced_event.__self__
def repair(grclass): def repair(grclass):
if not getattr(grclass, 'EVENTS', None): if not getattr(grclass, 'EVENTS', None):
return return
@@ -129,13 +149,7 @@ def repair(grclass):
for event in self.EVENTS: for event in self.EVENTS:
replaced_event = getattr(self, str(event)) replaced_event = getattr(self, str(event))
fun = EventWrapper(replaced_event)
def fun(*xargs, _js=None, replaced_event=replaced_event, **xkwargs):
if _js:
xkwargs['js'] = _js
return replaced_event(*xargs, **xkwargs)
setattr(self, str(event), fun) setattr(self, str(event), fun)
grclass.__init__ = __repaired_init__ grclass.__init__ = __repaired_init__

View File

@@ -26,7 +26,7 @@ import modules.shared as shared
from modules import prompt_parser from modules import prompt_parser
from modules.infotext_utils import image_from_url_text, PasteField from modules.infotext_utils import image_from_url_text, PasteField
from modules_forge.forge_canvas.canvas import ForgeCanvas, canvas_head from modules_forge.forge_canvas.canvas import ForgeCanvas, canvas_head
from modules_forge import main_entry from modules_forge import main_entry, forge_space
create_setting_component = ui_settings.create_setting_component create_setting_component = ui_settings.create_setting_component
@@ -853,6 +853,9 @@ def create_ui():
extra_tabs.__exit__() extra_tabs.__exit__()
with gr.Blocks(analytics_enabled=False, head=canvas_head) as space_interface:
forge_space.main_entry()
scripts.scripts_current = None scripts.scripts_current = None
with gr.Blocks(analytics_enabled=False, head=canvas_head) as extras_interface: with gr.Blocks(analytics_enabled=False, head=canvas_head) as extras_interface:
@@ -891,6 +894,7 @@ def create_ui():
interfaces = [ interfaces = [
(txt2img_interface, "Txt2img", "txt2img"), (txt2img_interface, "Txt2img", "txt2img"),
(img2img_interface, "Img2img", "img2img"), (img2img_interface, "Img2img", "img2img"),
(space_interface, "Spaces", "space"),
(extras_interface, "Extras", "extras"), (extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"), (pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"), (modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"),

View File

@@ -0,0 +1,159 @@
import os
import sys
import uuid
import time
import gradio as gr
import importlib.util
import shutil
from gradio.context import Context
from threading import Thread
from huggingface_hub import snapshot_download
from backend import memory_management
spaces = []
def build_html(title, installed=False, url=None):
if not installed:
return f'<div>{title}</div><div style="color: grey;">Not Installed</div>'
if isinstance(url, str):
return f'<div>{title}</div><div>Currently Running: <a href="{url}" style="color: green;" target="_blank">{url}</a></div>'
else:
return f'<div>{title}</div><div style="color: grey;">Installed, Ready to Launch</div>'
class ForgeSpace:
def __init__(self, root_path, title, repo_id=None, repo_type='space', revision=None, **kwargs):
self.title = title
self.root_path = root_path
self.hf_path = os.path.join(root_path, 'huggingface_space_mirror')
self.repo_id = repo_id
self.repo_type = repo_type
self.revision = revision
self.is_running = False
self.gradio_metas = None
self.label = gr.HTML(build_html(title=title, url=None), elem_classes=['forge_space_label'])
self.btn_launch = gr.Button('Launch', elem_classes=['forge_space_btn'])
self.btn_terminate = gr.Button('Terminate', elem_classes=['forge_space_btn'])
self.btn_install = gr.Button('Install', elem_classes=['forge_space_btn'])
self.btn_uninstall = gr.Button('Uninstall', elem_classes=['forge_space_btn'])
comps = [
self.label,
self.btn_install,
self.btn_uninstall,
self.btn_launch,
self.btn_terminate
]
self.btn_launch.click(self.run, outputs=comps)
self.btn_terminate.click(self.terminate, outputs=comps)
self.btn_install.click(self.install, outputs=comps)
self.btn_uninstall.click(self.uninstall, outputs=comps)
Context.root_block.load(self.refresh_gradio, outputs=comps, queue=False, show_progress=False)
return
def refresh_gradio(self):
results = []
installed = os.path.exists(self.hf_path)
if isinstance(self.gradio_metas, tuple):
results.append(build_html(title=self.title, installed=installed, url=self.gradio_metas[1]))
else:
results.append(build_html(title=self.title, installed=installed, url=None))
results.append(gr.update(interactive=not installed))
results.append(gr.update(interactive=installed))
results.append(gr.update(interactive=installed and not self.is_running))
results.append(gr.update(interactive=installed and self.is_running))
return results
def install(self):
os.makedirs(self.hf_path, exist_ok=True)
if self.repo_id is None:
return self.refresh_gradio()
downloaded = snapshot_download(
repo_id=self.repo_id,
repo_type=self.repo_type,
revision=self.revision,
local_dir=self.hf_path,
force_download=True,
)
print(f'Downloaded: {downloaded}')
return self.refresh_gradio()
def uninstall(self):
shutil.rmtree(self.hf_path)
print(f'Deleted: {self.hf_path}')
return self.refresh_gradio()
def terminate(self):
self.is_running = False
while self.gradio_metas is not None:
time.sleep(0.1)
return self.refresh_gradio()
def run(self):
self.is_running = True
Thread(target=self.gradio_worker).start()
while self.gradio_metas is None:
time.sleep(0.1)
return self.refresh_gradio()
def gradio_worker(self):
memory_management.unload_all_models()
sys.path.insert(0, self.hf_path)
file_path = os.path.join(self.root_path, 'forge_app.py')
module_name = 'forge_space_' + str(uuid.uuid4()).replace('-', '_')
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
demo = getattr(module, 'demo')
self.gradio_metas = demo.launch(inbrowser=True, prevent_thread_lock=True)
while self.is_running:
time.sleep(0.1)
demo.close()
self.gradio_metas = None
if module_name in sys.modules:
del sys.modules[module_name]
return
def main_entry():
global spaces
from modules.extensions import extensions
tagged_extensions = {}
for ex in extensions:
if ex.enabled and ex.is_forge_space:
tag = ex.space_meta['tag']
if tag not in tagged_extensions:
tagged_extensions[tag] = []
tagged_extensions[tag].append(ex)
for tag, exs in tagged_extensions.items():
with gr.Accordion(tag, open=True):
for ex in exs:
with gr.Row(equal_height=True):
space = ForgeSpace(root_path=ex.path, **ex.space_meta)
spaces.append(space)
return

View File

@@ -2,6 +2,7 @@ import os
import sys import sys
INITIALIZED = False
MONITOR_MODEL_MOVING = False MONITOR_MODEL_MOVING = False
@@ -25,6 +26,13 @@ def monitor_module_moving():
def initialize_forge(): def initialize_forge():
global INITIALIZED
if INITIALIZED:
return
INITIALIZED = True
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'packages_3rdparty')) sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'packages_3rdparty'))
bad_list = ['--lowvram', '--medvram', '--medvram-sdxl'] bad_list = ['--lowvram', '--medvram', '--medvram-sdxl']
@@ -60,9 +68,6 @@ def initialize_forge():
from modules_forge.bnb_installer import try_install_bnb from modules_forge.bnb_installer import try_install_bnb
try_install_bnb() try_install_bnb()
import modules_forge.patch_basic
modules_forge.patch_basic.patch_all_basics()
from backend import stream from backend import stream
print('CUDA Using Stream:', stream.should_use_stream()) print('CUDA Using Stream:', stream.should_use_stream())
@@ -85,4 +90,8 @@ def initialize_forge():
if 'HF_HUB_CACHE' not in os.environ: if 'HF_HUB_CACHE' not in os.environ:
os.environ['HF_HUB_CACHE'] = diffusers_dir os.environ['HF_HUB_CACHE'] = diffusers_dir
import modules_forge.patch_basic
modules_forge.patch_basic.patch_all_basics()
return return

View File

@@ -6,6 +6,8 @@ import warnings
import gradio.networking import gradio.networking
import safetensors.torch import safetensors.torch
from tqdm import tqdm
def gradio_url_ok_fix(url: str) -> bool: def gradio_url_ok_fix(url: str) -> bool:
try: try:
@@ -55,7 +57,20 @@ def build_loaded(module, loader_name):
return return
def always_show_tqdm(*args, **kwargs):
kwargs['disable'] = False
if 'name' in kwargs:
del kwargs['name']
return tqdm(*args, **kwargs)
def patch_all_basics(): def patch_all_basics():
import logging
from huggingface_hub import file_download
file_download.tqdm = always_show_tqdm
from transformers.dynamic_module_utils import logger
logger.setLevel(logging.ERROR)
gradio.networking.url_ok = gradio_url_ok_fix gradio.networking.url_ok = gradio_url_ok_fix
build_loaded(safetensors.torch, 'load_file') build_loaded(safetensors.torch, 'load_file')
build_loaded(torch, 'load') build_loaded(torch, 'load')

View File

@@ -35,4 +35,6 @@ httpx==0.24.1
pillow-avif-plugin==1.4.3 pillow-avif-plugin==1.4.3
diffusers==0.29.2 diffusers==0.29.2
gradio_rangeslider==0.0.6 gradio_rangeslider==0.0.6
gradio_imageslider==0.0.20
loadimg==0.1.2
tqdm==4.66.1 tqdm==4.66.1

79
spaces.py Normal file
View File

@@ -0,0 +1,79 @@
from modules_forge.initialization import initialize_forge
initialize_forge()
import os
import torch
import inspect
from backend import memory_management
gpu = memory_management.get_torch_device()
class GPUObject:
def __init__(self):
self.module_list = []
def __enter__(self):
self.original_init = torch.nn.Module.__init__
self.original_to = torch.nn.Module.to
def patched_init(module, *args, **kwargs):
self.module_list.append(module)
return self.original_init(module, *args, **kwargs)
def patched_to(module, *args, **kwargs):
self.module_list.append(module)
return self.original_to(module, *args, **kwargs)
torch.nn.Module.__init__ = patched_init
torch.nn.Module.to = patched_to
return self
def __exit__(self, exc_type, exc_val, exc_tb):
torch.nn.Module.__init__ = self.original_init
torch.nn.Module.to = self.original_to
self.module_list = set(self.module_list)
self.to(device=torch.device('cpu'))
memory_management.soft_empty_cache()
return
def to(self, device):
for module in self.module_list:
module.to(device)
print(f'Forge Space: Moved {len(self.module_list)} Modules to {device}')
return self
def gpu(self):
self.to(device=gpu)
return self
def GPU(gpu_objects=None, manual_load=False):
gpu_objects = gpu_objects or []
def decorator(func):
def wrapper(*args, **kwargs):
print("Entering Forge Space GPU ...")
memory_management.unload_all_models()
if not manual_load:
for o in gpu_objects:
o.gpu()
result = func(*args, **kwargs)
print("Cleaning Forge Space GPU ...")
for o in gpu_objects:
o.to(device=torch.device('cpu'))
memory_management.soft_empty_cache()
return result
return wrapper
return decorator
def convert_root_path():
frame = inspect.currentframe().f_back
caller_file = frame.f_code.co_filename
caller_file = os.path.abspath(caller_file)
result = os.path.join(os.path.dirname(caller_file), 'huggingface_space_mirror')
return result + '/'

View File

@@ -1673,3 +1673,17 @@ body.resizing .resize-handle {
#quicksettings .gradio-slider span { #quicksettings .gradio-slider span {
padding-right: 5px; padding-right: 5px;
} }
.forge_space_label{
padding: 10px;
min-width: 60% !important;
margin: 1px;
border-width: 1px;
border-radius: 8px;
border-style: solid;
border-color: #6f6f6f;
}
.forge_space_btn{
min-width: 0 !important;
}

View File

@@ -11,8 +11,6 @@ from modules_forge.initialization import initialize_forge
from modules_forge import main_thread from modules_forge import main_thread
from modules_forge.forge_canvas.canvas import canvas_js_root_path
startup_timer = timer.startup_timer startup_timer = timer.startup_timer
startup_timer.record("launcher") startup_timer.record("launcher")
@@ -83,6 +81,8 @@ def webui_worker():
elif shared.opts.auto_launch_browser == "Local": elif shared.opts.auto_launch_browser == "Local":
auto_launch_browser = not cmd_opts.webui_is_non_local auto_launch_browser = not cmd_opts.webui_is_non_local
from modules_forge.forge_canvas.canvas import canvas_js_root_path
app, local_url, share_url = shared.demo.launch( app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share, share=cmd_opts.share,
server_name=initialize_util.gradio_server_name(), server_name=initialize_util.gradio_server_name(),