mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-01-26 19:09:45 +00:00
Forge Space and BiRefNet
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,3 +1,4 @@
|
||||
huggingface_space_mirror/
|
||||
random_test.py
|
||||
__pycache__
|
||||
*.ckpt
|
||||
|
||||
67
extensions-builtin/forge_space_birefnet/forge_app.py
Normal file
67
extensions-builtin/forge_space_birefnet/forge_app.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import spaces
|
||||
import os
|
||||
import gradio as gr
|
||||
from gradio_imageslider import ImageSlider
|
||||
from loadimg import load_img
|
||||
from transformers import AutoModelForImageSegmentation
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
|
||||
torch.set_float32_matmul_precision(["high", "highest"][0])
|
||||
|
||||
os.environ['HOME'] = spaces.convert_root_path() + 'home'
|
||||
|
||||
with spaces.GPUObject() as birefnet_gpu_obj:
|
||||
birefnet = AutoModelForImageSegmentation.from_pretrained(
|
||||
"ZhengPeng7/BiRefNet", trust_remote_code=True
|
||||
)
|
||||
|
||||
transform_image = transforms.Compose(
|
||||
[
|
||||
transforms.Resize((1024, 1024)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@spaces.GPU(gpu_objects=[birefnet_gpu_obj])
|
||||
def fn(image):
|
||||
im = load_img(image, output_type="pil")
|
||||
im = im.convert("RGB")
|
||||
image_size = im.size
|
||||
origin = im.copy()
|
||||
image = load_img(im)
|
||||
input_images = transform_image(image).unsqueeze(0).to(spaces.gpu)
|
||||
# Prediction
|
||||
with torch.no_grad():
|
||||
preds = birefnet(input_images)[-1].sigmoid().cpu()
|
||||
pred = preds[0].squeeze()
|
||||
pred_pil = transforms.ToPILImage()(pred)
|
||||
mask = pred_pil.resize(image_size)
|
||||
image.putalpha(mask)
|
||||
return (image, origin)
|
||||
|
||||
|
||||
slider1 = ImageSlider(label="birefnet", type="pil")
|
||||
slider2 = ImageSlider(label="birefnet", type="pil")
|
||||
image = gr.Image(label="Upload an image")
|
||||
text = gr.Textbox(label="Paste an image URL")
|
||||
|
||||
|
||||
chameleon = load_img(spaces.convert_root_path() + "chameleon.jpg", output_type="pil")
|
||||
|
||||
url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
|
||||
tab1 = gr.Interface(
|
||||
fn, inputs=image, outputs=slider1, examples=[chameleon], api_name="image", allow_flagging="never"
|
||||
)
|
||||
|
||||
tab2 = gr.Interface(fn, inputs=text, outputs=slider2, examples=[url], api_name="text", allow_flagging="never")
|
||||
|
||||
|
||||
demo = gr.TabbedInterface(
|
||||
[tab1, tab2], ["image", "text"], title="birefnet for background removal"
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch(inbrowser=True)
|
||||
5
extensions-builtin/forge_space_birefnet/space_meta.json
Normal file
5
extensions-builtin/forge_space_birefnet/space_meta.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"tag": "Image Processing: Matting, Saliency, and Background Removal",
|
||||
"title": "BiRefNet for Background Removal",
|
||||
"repo_id": "not-lain/background-removal"
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import dataclasses
|
||||
import os
|
||||
import threading
|
||||
import re
|
||||
import json
|
||||
|
||||
from modules import shared, errors, cache, scripts
|
||||
from modules.gitpython_hack import Repo
|
||||
@@ -124,6 +125,13 @@ class Extension:
|
||||
self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
|
||||
self.canonical_name = metadata.canonical_name
|
||||
|
||||
self.is_forge_space = False
|
||||
self.space_meta = None
|
||||
|
||||
if os.path.exists(os.path.join(self.path, 'space_meta.json')) and os.path.exists(os.path.join(self.path, 'forge_app.py')):
|
||||
self.is_forge_space = True
|
||||
self.space_meta = json.load(open(os.path.join(self.path, 'space_meta.json'), 'rt', encoding='utf-8'))
|
||||
|
||||
def to_dict(self):
|
||||
return {x: getattr(self, x) for x in self.cached_fields}
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import inspect
|
||||
import types
|
||||
import warnings
|
||||
from functools import wraps
|
||||
|
||||
@@ -106,6 +107,25 @@ gradio_component_meta_create_or_modify_pyi_original = patches.patch(__file__, gr
|
||||
# this function is broken and does not seem to do anything useful
|
||||
gradio.component_meta.updateable = lambda x: x
|
||||
|
||||
|
||||
class EventWrapper:
|
||||
def __init__(self, replaced_event):
|
||||
self.replaced_event = replaced_event
|
||||
self.has_trigger = replaced_event.has_trigger
|
||||
self.event_name = replaced_event.event_name
|
||||
self.callback = replaced_event.callback
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if '_js' in kwargs:
|
||||
kwargs['js'] = kwargs['_js']
|
||||
del kwargs['_js']
|
||||
return self.replaced_event(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def __self__(self):
|
||||
return self.replaced_event.__self__
|
||||
|
||||
|
||||
def repair(grclass):
|
||||
if not getattr(grclass, 'EVENTS', None):
|
||||
return
|
||||
@@ -129,13 +149,7 @@ def repair(grclass):
|
||||
|
||||
for event in self.EVENTS:
|
||||
replaced_event = getattr(self, str(event))
|
||||
|
||||
def fun(*xargs, _js=None, replaced_event=replaced_event, **xkwargs):
|
||||
if _js:
|
||||
xkwargs['js'] = _js
|
||||
|
||||
return replaced_event(*xargs, **xkwargs)
|
||||
|
||||
fun = EventWrapper(replaced_event)
|
||||
setattr(self, str(event), fun)
|
||||
|
||||
grclass.__init__ = __repaired_init__
|
||||
|
||||
@@ -26,7 +26,7 @@ import modules.shared as shared
|
||||
from modules import prompt_parser
|
||||
from modules.infotext_utils import image_from_url_text, PasteField
|
||||
from modules_forge.forge_canvas.canvas import ForgeCanvas, canvas_head
|
||||
from modules_forge import main_entry
|
||||
from modules_forge import main_entry, forge_space
|
||||
|
||||
|
||||
create_setting_component = ui_settings.create_setting_component
|
||||
@@ -853,6 +853,9 @@ def create_ui():
|
||||
|
||||
extra_tabs.__exit__()
|
||||
|
||||
with gr.Blocks(analytics_enabled=False, head=canvas_head) as space_interface:
|
||||
forge_space.main_entry()
|
||||
|
||||
scripts.scripts_current = None
|
||||
|
||||
with gr.Blocks(analytics_enabled=False, head=canvas_head) as extras_interface:
|
||||
@@ -891,6 +894,7 @@ def create_ui():
|
||||
interfaces = [
|
||||
(txt2img_interface, "Txt2img", "txt2img"),
|
||||
(img2img_interface, "Img2img", "img2img"),
|
||||
(space_interface, "Spaces", "space"),
|
||||
(extras_interface, "Extras", "extras"),
|
||||
(pnginfo_interface, "PNG Info", "pnginfo"),
|
||||
(modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"),
|
||||
|
||||
159
modules_forge/forge_space.py
Normal file
159
modules_forge/forge_space.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
import time
|
||||
import gradio as gr
|
||||
import importlib.util
|
||||
import shutil
|
||||
|
||||
from gradio.context import Context
|
||||
from threading import Thread
|
||||
from huggingface_hub import snapshot_download
|
||||
from backend import memory_management
|
||||
|
||||
|
||||
spaces = []
|
||||
|
||||
|
||||
def build_html(title, installed=False, url=None):
|
||||
if not installed:
|
||||
return f'<div>{title}</div><div style="color: grey;">Not Installed</div>'
|
||||
|
||||
if isinstance(url, str):
|
||||
return f'<div>{title}</div><div>Currently Running: <a href="{url}" style="color: green;" target="_blank">{url}</a></div>'
|
||||
else:
|
||||
return f'<div>{title}</div><div style="color: grey;">Installed, Ready to Launch</div>'
|
||||
|
||||
|
||||
class ForgeSpace:
|
||||
def __init__(self, root_path, title, repo_id=None, repo_type='space', revision=None, **kwargs):
|
||||
self.title = title
|
||||
self.root_path = root_path
|
||||
self.hf_path = os.path.join(root_path, 'huggingface_space_mirror')
|
||||
self.repo_id = repo_id
|
||||
self.repo_type = repo_type
|
||||
self.revision = revision
|
||||
self.is_running = False
|
||||
self.gradio_metas = None
|
||||
|
||||
self.label = gr.HTML(build_html(title=title, url=None), elem_classes=['forge_space_label'])
|
||||
self.btn_launch = gr.Button('Launch', elem_classes=['forge_space_btn'])
|
||||
self.btn_terminate = gr.Button('Terminate', elem_classes=['forge_space_btn'])
|
||||
self.btn_install = gr.Button('Install', elem_classes=['forge_space_btn'])
|
||||
self.btn_uninstall = gr.Button('Uninstall', elem_classes=['forge_space_btn'])
|
||||
|
||||
comps = [
|
||||
self.label,
|
||||
self.btn_install,
|
||||
self.btn_uninstall,
|
||||
self.btn_launch,
|
||||
self.btn_terminate
|
||||
]
|
||||
|
||||
self.btn_launch.click(self.run, outputs=comps)
|
||||
self.btn_terminate.click(self.terminate, outputs=comps)
|
||||
self.btn_install.click(self.install, outputs=comps)
|
||||
self.btn_uninstall.click(self.uninstall, outputs=comps)
|
||||
Context.root_block.load(self.refresh_gradio, outputs=comps, queue=False, show_progress=False)
|
||||
|
||||
return
|
||||
|
||||
def refresh_gradio(self):
|
||||
results = []
|
||||
|
||||
installed = os.path.exists(self.hf_path)
|
||||
|
||||
if isinstance(self.gradio_metas, tuple):
|
||||
results.append(build_html(title=self.title, installed=installed, url=self.gradio_metas[1]))
|
||||
else:
|
||||
results.append(build_html(title=self.title, installed=installed, url=None))
|
||||
|
||||
results.append(gr.update(interactive=not installed))
|
||||
results.append(gr.update(interactive=installed))
|
||||
results.append(gr.update(interactive=installed and not self.is_running))
|
||||
results.append(gr.update(interactive=installed and self.is_running))
|
||||
return results
|
||||
|
||||
def install(self):
|
||||
os.makedirs(self.hf_path, exist_ok=True)
|
||||
|
||||
if self.repo_id is None:
|
||||
return self.refresh_gradio()
|
||||
|
||||
downloaded = snapshot_download(
|
||||
repo_id=self.repo_id,
|
||||
repo_type=self.repo_type,
|
||||
revision=self.revision,
|
||||
local_dir=self.hf_path,
|
||||
force_download=True,
|
||||
)
|
||||
|
||||
print(f'Downloaded: {downloaded}')
|
||||
return self.refresh_gradio()
|
||||
|
||||
def uninstall(self):
|
||||
shutil.rmtree(self.hf_path)
|
||||
print(f'Deleted: {self.hf_path}')
|
||||
return self.refresh_gradio()
|
||||
|
||||
def terminate(self):
|
||||
self.is_running = False
|
||||
while self.gradio_metas is not None:
|
||||
time.sleep(0.1)
|
||||
return self.refresh_gradio()
|
||||
|
||||
def run(self):
|
||||
self.is_running = True
|
||||
Thread(target=self.gradio_worker).start()
|
||||
while self.gradio_metas is None:
|
||||
time.sleep(0.1)
|
||||
return self.refresh_gradio()
|
||||
|
||||
def gradio_worker(self):
|
||||
memory_management.unload_all_models()
|
||||
sys.path.insert(0, self.hf_path)
|
||||
file_path = os.path.join(self.root_path, 'forge_app.py')
|
||||
module_name = 'forge_space_' + str(uuid.uuid4()).replace('-', '_')
|
||||
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
demo = getattr(module, 'demo')
|
||||
|
||||
self.gradio_metas = demo.launch(inbrowser=True, prevent_thread_lock=True)
|
||||
|
||||
while self.is_running:
|
||||
time.sleep(0.1)
|
||||
|
||||
demo.close()
|
||||
self.gradio_metas = None
|
||||
|
||||
if module_name in sys.modules:
|
||||
del sys.modules[module_name]
|
||||
|
||||
return
|
||||
|
||||
|
||||
def main_entry():
|
||||
global spaces
|
||||
|
||||
from modules.extensions import extensions
|
||||
|
||||
tagged_extensions = {}
|
||||
|
||||
for ex in extensions:
|
||||
if ex.enabled and ex.is_forge_space:
|
||||
tag = ex.space_meta['tag']
|
||||
|
||||
if tag not in tagged_extensions:
|
||||
tagged_extensions[tag] = []
|
||||
|
||||
tagged_extensions[tag].append(ex)
|
||||
|
||||
for tag, exs in tagged_extensions.items():
|
||||
with gr.Accordion(tag, open=True):
|
||||
for ex in exs:
|
||||
with gr.Row(equal_height=True):
|
||||
space = ForgeSpace(root_path=ex.path, **ex.space_meta)
|
||||
spaces.append(space)
|
||||
|
||||
return
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
import sys
|
||||
|
||||
|
||||
INITIALIZED = False
|
||||
MONITOR_MODEL_MOVING = False
|
||||
|
||||
|
||||
@@ -25,6 +26,13 @@ def monitor_module_moving():
|
||||
|
||||
|
||||
def initialize_forge():
|
||||
global INITIALIZED
|
||||
|
||||
if INITIALIZED:
|
||||
return
|
||||
|
||||
INITIALIZED = True
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'packages_3rdparty'))
|
||||
|
||||
bad_list = ['--lowvram', '--medvram', '--medvram-sdxl']
|
||||
@@ -60,9 +68,6 @@ def initialize_forge():
|
||||
from modules_forge.bnb_installer import try_install_bnb
|
||||
try_install_bnb()
|
||||
|
||||
import modules_forge.patch_basic
|
||||
modules_forge.patch_basic.patch_all_basics()
|
||||
|
||||
from backend import stream
|
||||
print('CUDA Using Stream:', stream.should_use_stream())
|
||||
|
||||
@@ -85,4 +90,8 @@ def initialize_forge():
|
||||
|
||||
if 'HF_HUB_CACHE' not in os.environ:
|
||||
os.environ['HF_HUB_CACHE'] = diffusers_dir
|
||||
|
||||
import modules_forge.patch_basic
|
||||
modules_forge.patch_basic.patch_all_basics()
|
||||
|
||||
return
|
||||
|
||||
@@ -6,6 +6,8 @@ import warnings
|
||||
import gradio.networking
|
||||
import safetensors.torch
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def gradio_url_ok_fix(url: str) -> bool:
|
||||
try:
|
||||
@@ -55,7 +57,20 @@ def build_loaded(module, loader_name):
|
||||
return
|
||||
|
||||
|
||||
def always_show_tqdm(*args, **kwargs):
|
||||
kwargs['disable'] = False
|
||||
if 'name' in kwargs:
|
||||
del kwargs['name']
|
||||
return tqdm(*args, **kwargs)
|
||||
|
||||
|
||||
def patch_all_basics():
|
||||
import logging
|
||||
from huggingface_hub import file_download
|
||||
file_download.tqdm = always_show_tqdm
|
||||
from transformers.dynamic_module_utils import logger
|
||||
logger.setLevel(logging.ERROR)
|
||||
|
||||
gradio.networking.url_ok = gradio_url_ok_fix
|
||||
build_loaded(safetensors.torch, 'load_file')
|
||||
build_loaded(torch, 'load')
|
||||
|
||||
@@ -35,4 +35,6 @@ httpx==0.24.1
|
||||
pillow-avif-plugin==1.4.3
|
||||
diffusers==0.29.2
|
||||
gradio_rangeslider==0.0.6
|
||||
gradio_imageslider==0.0.20
|
||||
loadimg==0.1.2
|
||||
tqdm==4.66.1
|
||||
|
||||
79
spaces.py
Normal file
79
spaces.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from modules_forge.initialization import initialize_forge
|
||||
|
||||
initialize_forge()
|
||||
|
||||
import os
|
||||
import torch
|
||||
import inspect
|
||||
|
||||
from backend import memory_management
|
||||
|
||||
|
||||
gpu = memory_management.get_torch_device()
|
||||
|
||||
|
||||
class GPUObject:
|
||||
def __init__(self):
|
||||
self.module_list = []
|
||||
|
||||
def __enter__(self):
|
||||
self.original_init = torch.nn.Module.__init__
|
||||
self.original_to = torch.nn.Module.to
|
||||
|
||||
def patched_init(module, *args, **kwargs):
|
||||
self.module_list.append(module)
|
||||
return self.original_init(module, *args, **kwargs)
|
||||
|
||||
def patched_to(module, *args, **kwargs):
|
||||
self.module_list.append(module)
|
||||
return self.original_to(module, *args, **kwargs)
|
||||
|
||||
torch.nn.Module.__init__ = patched_init
|
||||
torch.nn.Module.to = patched_to
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
torch.nn.Module.__init__ = self.original_init
|
||||
torch.nn.Module.to = self.original_to
|
||||
self.module_list = set(self.module_list)
|
||||
self.to(device=torch.device('cpu'))
|
||||
memory_management.soft_empty_cache()
|
||||
return
|
||||
|
||||
def to(self, device):
|
||||
for module in self.module_list:
|
||||
module.to(device)
|
||||
print(f'Forge Space: Moved {len(self.module_list)} Modules to {device}')
|
||||
return self
|
||||
|
||||
def gpu(self):
|
||||
self.to(device=gpu)
|
||||
return self
|
||||
|
||||
|
||||
def GPU(gpu_objects=None, manual_load=False):
|
||||
gpu_objects = gpu_objects or []
|
||||
|
||||
def decorator(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
print("Entering Forge Space GPU ...")
|
||||
memory_management.unload_all_models()
|
||||
if not manual_load:
|
||||
for o in gpu_objects:
|
||||
o.gpu()
|
||||
result = func(*args, **kwargs)
|
||||
print("Cleaning Forge Space GPU ...")
|
||||
for o in gpu_objects:
|
||||
o.to(device=torch.device('cpu'))
|
||||
memory_management.soft_empty_cache()
|
||||
return result
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def convert_root_path():
|
||||
frame = inspect.currentframe().f_back
|
||||
caller_file = frame.f_code.co_filename
|
||||
caller_file = os.path.abspath(caller_file)
|
||||
result = os.path.join(os.path.dirname(caller_file), 'huggingface_space_mirror')
|
||||
return result + '/'
|
||||
14
style.css
14
style.css
@@ -1673,3 +1673,17 @@ body.resizing .resize-handle {
|
||||
#quicksettings .gradio-slider span {
|
||||
padding-right: 5px;
|
||||
}
|
||||
|
||||
.forge_space_label{
|
||||
padding: 10px;
|
||||
min-width: 60% !important;
|
||||
margin: 1px;
|
||||
border-width: 1px;
|
||||
border-radius: 8px;
|
||||
border-style: solid;
|
||||
border-color: #6f6f6f;
|
||||
}
|
||||
|
||||
.forge_space_btn{
|
||||
min-width: 0 !important;
|
||||
}
|
||||
|
||||
4
webui.py
4
webui.py
@@ -11,8 +11,6 @@ from modules_forge.initialization import initialize_forge
|
||||
from modules_forge import main_thread
|
||||
|
||||
|
||||
from modules_forge.forge_canvas.canvas import canvas_js_root_path
|
||||
|
||||
startup_timer = timer.startup_timer
|
||||
startup_timer.record("launcher")
|
||||
|
||||
@@ -83,6 +81,8 @@ def webui_worker():
|
||||
elif shared.opts.auto_launch_browser == "Local":
|
||||
auto_launch_browser = not cmd_opts.webui_is_non_local
|
||||
|
||||
from modules_forge.forge_canvas.canvas import canvas_js_root_path
|
||||
|
||||
app, local_url, share_url = shared.demo.launch(
|
||||
share=cmd_opts.share,
|
||||
server_name=initialize_util.gradio_server_name(),
|
||||
|
||||
Reference in New Issue
Block a user