diff --git a/.gitignore b/.gitignore
index 53836027..875acee8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
+huggingface_space_mirror/
random_test.py
__pycache__
*.ckpt
diff --git a/extensions-builtin/forge_space_birefnet/forge_app.py b/extensions-builtin/forge_space_birefnet/forge_app.py
new file mode 100644
index 00000000..6199819e
--- /dev/null
+++ b/extensions-builtin/forge_space_birefnet/forge_app.py
@@ -0,0 +1,67 @@
+import spaces
+import os
+import gradio as gr
+from gradio_imageslider import ImageSlider
+from loadimg import load_img
+from transformers import AutoModelForImageSegmentation
+import torch
+from torchvision import transforms
+
+torch.set_float32_matmul_precision(["high", "highest"][0])
+
+os.environ['HOME'] = spaces.convert_root_path() + 'home'
+
+with spaces.GPUObject() as birefnet_gpu_obj:
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
+ )
+
+transform_image = transforms.Compose(
+ [
+ transforms.Resize((1024, 1024)),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ ]
+)
+
+
+@spaces.GPU(gpu_objects=[birefnet_gpu_obj])
+def fn(image):
+ im = load_img(image, output_type="pil")
+ im = im.convert("RGB")
+ image_size = im.size
+ origin = im.copy()
+ image = load_img(im)
+ input_images = transform_image(image).unsqueeze(0).to(spaces.gpu)
+ # Prediction
+ with torch.no_grad():
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
+ pred = preds[0].squeeze()
+ pred_pil = transforms.ToPILImage()(pred)
+ mask = pred_pil.resize(image_size)
+ image.putalpha(mask)
+ return (image, origin)
+
+
+slider1 = ImageSlider(label="birefnet", type="pil")
+slider2 = ImageSlider(label="birefnet", type="pil")
+image = gr.Image(label="Upload an image")
+text = gr.Textbox(label="Paste an image URL")
+
+
+chameleon = load_img(spaces.convert_root_path() + "chameleon.jpg", output_type="pil")
+
+url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
+tab1 = gr.Interface(
+ fn, inputs=image, outputs=slider1, examples=[chameleon], api_name="image", allow_flagging="never"
+)
+
+tab2 = gr.Interface(fn, inputs=text, outputs=slider2, examples=[url], api_name="text", allow_flagging="never")
+
+
+demo = gr.TabbedInterface(
+ [tab1, tab2], ["image", "text"], title="birefnet for background removal"
+)
+
+if __name__ == "__main__":
+ demo.launch(inbrowser=True)
diff --git a/extensions-builtin/forge_space_birefnet/space_meta.json b/extensions-builtin/forge_space_birefnet/space_meta.json
new file mode 100644
index 00000000..03702128
--- /dev/null
+++ b/extensions-builtin/forge_space_birefnet/space_meta.json
@@ -0,0 +1,5 @@
+{
+ "tag": "Image Processing: Matting, Saliency, and Background Removal",
+ "title": "BiRefNet for Background Removal",
+ "repo_id": "not-lain/background-removal"
+}
diff --git a/modules/extensions.py b/modules/extensions.py
index 715a864c..76f54fbb 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -5,6 +5,7 @@ import dataclasses
import os
import threading
import re
+import json
from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo
@@ -124,6 +125,13 @@ class Extension:
self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
self.canonical_name = metadata.canonical_name
+ self.is_forge_space = False
+ self.space_meta = None
+
+ if os.path.exists(os.path.join(self.path, 'space_meta.json')) and os.path.exists(os.path.join(self.path, 'forge_app.py')):
+ self.is_forge_space = True
+ self.space_meta = json.load(open(os.path.join(self.path, 'space_meta.json'), 'rt', encoding='utf-8'))
+
def to_dict(self):
return {x: getattr(self, x) for x in self.cached_fields}
diff --git a/modules/gradio_extensions.py b/modules/gradio_extensions.py
index 84414f6e..755997b7 100644
--- a/modules/gradio_extensions.py
+++ b/modules/gradio_extensions.py
@@ -1,4 +1,5 @@
import inspect
+import types
import warnings
from functools import wraps
@@ -106,6 +107,25 @@ gradio_component_meta_create_or_modify_pyi_original = patches.patch(__file__, gr
# this function is broken and does not seem to do anything useful
gradio.component_meta.updateable = lambda x: x
+
+class EventWrapper:
+ def __init__(self, replaced_event):
+ self.replaced_event = replaced_event
+ self.has_trigger = replaced_event.has_trigger
+ self.event_name = replaced_event.event_name
+ self.callback = replaced_event.callback
+
+ def __call__(self, *args, **kwargs):
+ if '_js' in kwargs:
+ kwargs['js'] = kwargs['_js']
+ del kwargs['_js']
+ return self.replaced_event(*args, **kwargs)
+
+ @property
+ def __self__(self):
+ return self.replaced_event.__self__
+
+
def repair(grclass):
if not getattr(grclass, 'EVENTS', None):
return
@@ -129,13 +149,7 @@ def repair(grclass):
for event in self.EVENTS:
replaced_event = getattr(self, str(event))
-
- def fun(*xargs, _js=None, replaced_event=replaced_event, **xkwargs):
- if _js:
- xkwargs['js'] = _js
-
- return replaced_event(*xargs, **xkwargs)
-
+ fun = EventWrapper(replaced_event)
setattr(self, str(event), fun)
grclass.__init__ = __repaired_init__
diff --git a/modules/ui.py b/modules/ui.py
index 618e9da9..b62f45ec 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -26,7 +26,7 @@ import modules.shared as shared
from modules import prompt_parser
from modules.infotext_utils import image_from_url_text, PasteField
from modules_forge.forge_canvas.canvas import ForgeCanvas, canvas_head
-from modules_forge import main_entry
+from modules_forge import main_entry, forge_space
create_setting_component = ui_settings.create_setting_component
@@ -853,6 +853,9 @@ def create_ui():
extra_tabs.__exit__()
+ with gr.Blocks(analytics_enabled=False, head=canvas_head) as space_interface:
+ forge_space.main_entry()
+
scripts.scripts_current = None
with gr.Blocks(analytics_enabled=False, head=canvas_head) as extras_interface:
@@ -891,6 +894,7 @@ def create_ui():
interfaces = [
(txt2img_interface, "Txt2img", "txt2img"),
(img2img_interface, "Img2img", "img2img"),
+ (space_interface, "Spaces", "space"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"),
diff --git a/modules_forge/forge_space.py b/modules_forge/forge_space.py
new file mode 100644
index 00000000..0782ec14
--- /dev/null
+++ b/modules_forge/forge_space.py
@@ -0,0 +1,159 @@
+import os
+import sys
+import uuid
+import time
+import gradio as gr
+import importlib.util
+import shutil
+
+from gradio.context import Context
+from threading import Thread
+from huggingface_hub import snapshot_download
+from backend import memory_management
+
+
+spaces = []
+
+
+def build_html(title, installed=False, url=None):
+ if not installed:
+ return f'
{title}
Not Installed
'
+
+ if isinstance(url, str):
+ return f'{title}
'
+ else:
+ return f'{title}
Installed, Ready to Launch
'
+
+
+class ForgeSpace:
+ def __init__(self, root_path, title, repo_id=None, repo_type='space', revision=None, **kwargs):
+ self.title = title
+ self.root_path = root_path
+ self.hf_path = os.path.join(root_path, 'huggingface_space_mirror')
+ self.repo_id = repo_id
+ self.repo_type = repo_type
+ self.revision = revision
+ self.is_running = False
+ self.gradio_metas = None
+
+ self.label = gr.HTML(build_html(title=title, url=None), elem_classes=['forge_space_label'])
+ self.btn_launch = gr.Button('Launch', elem_classes=['forge_space_btn'])
+ self.btn_terminate = gr.Button('Terminate', elem_classes=['forge_space_btn'])
+ self.btn_install = gr.Button('Install', elem_classes=['forge_space_btn'])
+ self.btn_uninstall = gr.Button('Uninstall', elem_classes=['forge_space_btn'])
+
+ comps = [
+ self.label,
+ self.btn_install,
+ self.btn_uninstall,
+ self.btn_launch,
+ self.btn_terminate
+ ]
+
+ self.btn_launch.click(self.run, outputs=comps)
+ self.btn_terminate.click(self.terminate, outputs=comps)
+ self.btn_install.click(self.install, outputs=comps)
+ self.btn_uninstall.click(self.uninstall, outputs=comps)
+ Context.root_block.load(self.refresh_gradio, outputs=comps, queue=False, show_progress=False)
+
+ return
+
+ def refresh_gradio(self):
+ results = []
+
+ installed = os.path.exists(self.hf_path)
+
+ if isinstance(self.gradio_metas, tuple):
+ results.append(build_html(title=self.title, installed=installed, url=self.gradio_metas[1]))
+ else:
+ results.append(build_html(title=self.title, installed=installed, url=None))
+
+ results.append(gr.update(interactive=not installed))
+ results.append(gr.update(interactive=installed))
+ results.append(gr.update(interactive=installed and not self.is_running))
+ results.append(gr.update(interactive=installed and self.is_running))
+ return results
+
+ def install(self):
+ os.makedirs(self.hf_path, exist_ok=True)
+
+ if self.repo_id is None:
+ return self.refresh_gradio()
+
+ downloaded = snapshot_download(
+ repo_id=self.repo_id,
+ repo_type=self.repo_type,
+ revision=self.revision,
+ local_dir=self.hf_path,
+ force_download=True,
+ )
+
+ print(f'Downloaded: {downloaded}')
+ return self.refresh_gradio()
+
+ def uninstall(self):
+ shutil.rmtree(self.hf_path)
+ print(f'Deleted: {self.hf_path}')
+ return self.refresh_gradio()
+
+ def terminate(self):
+ self.is_running = False
+ while self.gradio_metas is not None:
+ time.sleep(0.1)
+ return self.refresh_gradio()
+
+ def run(self):
+ self.is_running = True
+ Thread(target=self.gradio_worker).start()
+ while self.gradio_metas is None:
+ time.sleep(0.1)
+ return self.refresh_gradio()
+
+ def gradio_worker(self):
+ memory_management.unload_all_models()
+ sys.path.insert(0, self.hf_path)
+ file_path = os.path.join(self.root_path, 'forge_app.py')
+ module_name = 'forge_space_' + str(uuid.uuid4()).replace('-', '_')
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ demo = getattr(module, 'demo')
+
+ self.gradio_metas = demo.launch(inbrowser=True, prevent_thread_lock=True)
+
+ while self.is_running:
+ time.sleep(0.1)
+
+ demo.close()
+ self.gradio_metas = None
+
+ if module_name in sys.modules:
+ del sys.modules[module_name]
+
+ return
+
+
+def main_entry():
+ global spaces
+
+ from modules.extensions import extensions
+
+ tagged_extensions = {}
+
+ for ex in extensions:
+ if ex.enabled and ex.is_forge_space:
+ tag = ex.space_meta['tag']
+
+ if tag not in tagged_extensions:
+ tagged_extensions[tag] = []
+
+ tagged_extensions[tag].append(ex)
+
+ for tag, exs in tagged_extensions.items():
+ with gr.Accordion(tag, open=True):
+ for ex in exs:
+ with gr.Row(equal_height=True):
+ space = ForgeSpace(root_path=ex.path, **ex.space_meta)
+ spaces.append(space)
+
+ return
diff --git a/modules_forge/initialization.py b/modules_forge/initialization.py
index 800909f8..0420dea4 100644
--- a/modules_forge/initialization.py
+++ b/modules_forge/initialization.py
@@ -2,6 +2,7 @@ import os
import sys
+INITIALIZED = False
MONITOR_MODEL_MOVING = False
@@ -25,6 +26,13 @@ def monitor_module_moving():
def initialize_forge():
+ global INITIALIZED
+
+ if INITIALIZED:
+ return
+
+ INITIALIZED = True
+
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'packages_3rdparty'))
bad_list = ['--lowvram', '--medvram', '--medvram-sdxl']
@@ -60,9 +68,6 @@ def initialize_forge():
from modules_forge.bnb_installer import try_install_bnb
try_install_bnb()
- import modules_forge.patch_basic
- modules_forge.patch_basic.patch_all_basics()
-
from backend import stream
print('CUDA Using Stream:', stream.should_use_stream())
@@ -85,4 +90,8 @@ def initialize_forge():
if 'HF_HUB_CACHE' not in os.environ:
os.environ['HF_HUB_CACHE'] = diffusers_dir
+
+ import modules_forge.patch_basic
+ modules_forge.patch_basic.patch_all_basics()
+
return
diff --git a/modules_forge/patch_basic.py b/modules_forge/patch_basic.py
index 87048a20..d22bc003 100644
--- a/modules_forge/patch_basic.py
+++ b/modules_forge/patch_basic.py
@@ -6,6 +6,8 @@ import warnings
import gradio.networking
import safetensors.torch
+from tqdm import tqdm
+
def gradio_url_ok_fix(url: str) -> bool:
try:
@@ -55,7 +57,20 @@ def build_loaded(module, loader_name):
return
+def always_show_tqdm(*args, **kwargs):
+ kwargs['disable'] = False
+ if 'name' in kwargs:
+ del kwargs['name']
+ return tqdm(*args, **kwargs)
+
+
def patch_all_basics():
+ import logging
+ from huggingface_hub import file_download
+ file_download.tqdm = always_show_tqdm
+ from transformers.dynamic_module_utils import logger
+ logger.setLevel(logging.ERROR)
+
gradio.networking.url_ok = gradio_url_ok_fix
build_loaded(safetensors.torch, 'load_file')
build_loaded(torch, 'load')
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 3c608adb..1d3f04cf 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -35,4 +35,6 @@ httpx==0.24.1
pillow-avif-plugin==1.4.3
diffusers==0.29.2
gradio_rangeslider==0.0.6
+gradio_imageslider==0.0.20
+loadimg==0.1.2
tqdm==4.66.1
diff --git a/spaces.py b/spaces.py
new file mode 100644
index 00000000..bb9dfc6f
--- /dev/null
+++ b/spaces.py
@@ -0,0 +1,79 @@
+from modules_forge.initialization import initialize_forge
+
+initialize_forge()
+
+import os
+import torch
+import inspect
+
+from backend import memory_management
+
+
+gpu = memory_management.get_torch_device()
+
+
+class GPUObject:
+ def __init__(self):
+ self.module_list = []
+
+ def __enter__(self):
+ self.original_init = torch.nn.Module.__init__
+ self.original_to = torch.nn.Module.to
+
+ def patched_init(module, *args, **kwargs):
+ self.module_list.append(module)
+ return self.original_init(module, *args, **kwargs)
+
+ def patched_to(module, *args, **kwargs):
+ self.module_list.append(module)
+ return self.original_to(module, *args, **kwargs)
+
+ torch.nn.Module.__init__ = patched_init
+ torch.nn.Module.to = patched_to
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ torch.nn.Module.__init__ = self.original_init
+ torch.nn.Module.to = self.original_to
+ self.module_list = set(self.module_list)
+ self.to(device=torch.device('cpu'))
+ memory_management.soft_empty_cache()
+ return
+
+ def to(self, device):
+ for module in self.module_list:
+ module.to(device)
+ print(f'Forge Space: Moved {len(self.module_list)} Modules to {device}')
+ return self
+
+ def gpu(self):
+ self.to(device=gpu)
+ return self
+
+
+def GPU(gpu_objects=None, manual_load=False):
+ gpu_objects = gpu_objects or []
+
+ def decorator(func):
+ def wrapper(*args, **kwargs):
+ print("Entering Forge Space GPU ...")
+ memory_management.unload_all_models()
+ if not manual_load:
+ for o in gpu_objects:
+ o.gpu()
+ result = func(*args, **kwargs)
+ print("Cleaning Forge Space GPU ...")
+ for o in gpu_objects:
+ o.to(device=torch.device('cpu'))
+ memory_management.soft_empty_cache()
+ return result
+ return wrapper
+ return decorator
+
+
+def convert_root_path():
+ frame = inspect.currentframe().f_back
+ caller_file = frame.f_code.co_filename
+ caller_file = os.path.abspath(caller_file)
+ result = os.path.join(os.path.dirname(caller_file), 'huggingface_space_mirror')
+ return result + '/'
diff --git a/style.css b/style.css
index 9cc2534f..29e7418d 100644
--- a/style.css
+++ b/style.css
@@ -1673,3 +1673,17 @@ body.resizing .resize-handle {
#quicksettings .gradio-slider span {
padding-right: 5px;
}
+
+.forge_space_label{
+ padding: 10px;
+ min-width: 60% !important;
+ margin: 1px;
+ border-width: 1px;
+ border-radius: 8px;
+ border-style: solid;
+ border-color: #6f6f6f;
+}
+
+.forge_space_btn{
+ min-width: 0 !important;
+}
diff --git a/webui.py b/webui.py
index 921edce9..ab0e6c2f 100644
--- a/webui.py
+++ b/webui.py
@@ -11,8 +11,6 @@ from modules_forge.initialization import initialize_forge
from modules_forge import main_thread
-from modules_forge.forge_canvas.canvas import canvas_js_root_path
-
startup_timer = timer.startup_timer
startup_timer.record("launcher")
@@ -83,6 +81,8 @@ def webui_worker():
elif shared.opts.auto_launch_browser == "Local":
auto_launch_browser = not cmd_opts.webui_is_non_local
+ from modules_forge.forge_canvas.canvas import canvas_js_root_path
+
app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share,
server_name=initialize_util.gradio_server_name(),