diff --git a/.gitignore b/.gitignore index 53836027..875acee8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +huggingface_space_mirror/ random_test.py __pycache__ *.ckpt diff --git a/extensions-builtin/forge_space_birefnet/forge_app.py b/extensions-builtin/forge_space_birefnet/forge_app.py new file mode 100644 index 00000000..6199819e --- /dev/null +++ b/extensions-builtin/forge_space_birefnet/forge_app.py @@ -0,0 +1,67 @@ +import spaces +import os +import gradio as gr +from gradio_imageslider import ImageSlider +from loadimg import load_img +from transformers import AutoModelForImageSegmentation +import torch +from torchvision import transforms + +torch.set_float32_matmul_precision(["high", "highest"][0]) + +os.environ['HOME'] = spaces.convert_root_path() + 'home' + +with spaces.GPUObject() as birefnet_gpu_obj: + birefnet = AutoModelForImageSegmentation.from_pretrained( + "ZhengPeng7/BiRefNet", trust_remote_code=True + ) + +transform_image = transforms.Compose( + [ + transforms.Resize((1024, 1024)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] +) + + +@spaces.GPU(gpu_objects=[birefnet_gpu_obj]) +def fn(image): + im = load_img(image, output_type="pil") + im = im.convert("RGB") + image_size = im.size + origin = im.copy() + image = load_img(im) + input_images = transform_image(image).unsqueeze(0).to(spaces.gpu) + # Prediction + with torch.no_grad(): + preds = birefnet(input_images)[-1].sigmoid().cpu() + pred = preds[0].squeeze() + pred_pil = transforms.ToPILImage()(pred) + mask = pred_pil.resize(image_size) + image.putalpha(mask) + return (image, origin) + + +slider1 = ImageSlider(label="birefnet", type="pil") +slider2 = ImageSlider(label="birefnet", type="pil") +image = gr.Image(label="Upload an image") +text = gr.Textbox(label="Paste an image URL") + + +chameleon = load_img(spaces.convert_root_path() + "chameleon.jpg", output_type="pil") + +url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg" +tab1 = gr.Interface( + fn, inputs=image, outputs=slider1, examples=[chameleon], api_name="image", allow_flagging="never" +) + +tab2 = gr.Interface(fn, inputs=text, outputs=slider2, examples=[url], api_name="text", allow_flagging="never") + + +demo = gr.TabbedInterface( + [tab1, tab2], ["image", "text"], title="birefnet for background removal" +) + +if __name__ == "__main__": + demo.launch(inbrowser=True) diff --git a/extensions-builtin/forge_space_birefnet/space_meta.json b/extensions-builtin/forge_space_birefnet/space_meta.json new file mode 100644 index 00000000..03702128 --- /dev/null +++ b/extensions-builtin/forge_space_birefnet/space_meta.json @@ -0,0 +1,5 @@ +{ + "tag": "Image Processing: Matting, Saliency, and Background Removal", + "title": "BiRefNet for Background Removal", + "repo_id": "not-lain/background-removal" +} diff --git a/modules/extensions.py b/modules/extensions.py index 715a864c..76f54fbb 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -5,6 +5,7 @@ import dataclasses import os import threading import re +import json from modules import shared, errors, cache, scripts from modules.gitpython_hack import Repo @@ -124,6 +125,13 @@ class Extension: self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower()) self.canonical_name = metadata.canonical_name + self.is_forge_space = False + self.space_meta = None + + if os.path.exists(os.path.join(self.path, 'space_meta.json')) and os.path.exists(os.path.join(self.path, 'forge_app.py')): + self.is_forge_space = True + self.space_meta = json.load(open(os.path.join(self.path, 'space_meta.json'), 'rt', encoding='utf-8')) + def to_dict(self): return {x: getattr(self, x) for x in self.cached_fields} diff --git a/modules/gradio_extensions.py b/modules/gradio_extensions.py index 84414f6e..755997b7 100644 --- a/modules/gradio_extensions.py +++ b/modules/gradio_extensions.py @@ -1,4 +1,5 @@ import inspect +import types import warnings from functools import wraps @@ -106,6 +107,25 @@ gradio_component_meta_create_or_modify_pyi_original = patches.patch(__file__, gr # this function is broken and does not seem to do anything useful gradio.component_meta.updateable = lambda x: x + +class EventWrapper: + def __init__(self, replaced_event): + self.replaced_event = replaced_event + self.has_trigger = replaced_event.has_trigger + self.event_name = replaced_event.event_name + self.callback = replaced_event.callback + + def __call__(self, *args, **kwargs): + if '_js' in kwargs: + kwargs['js'] = kwargs['_js'] + del kwargs['_js'] + return self.replaced_event(*args, **kwargs) + + @property + def __self__(self): + return self.replaced_event.__self__ + + def repair(grclass): if not getattr(grclass, 'EVENTS', None): return @@ -129,13 +149,7 @@ def repair(grclass): for event in self.EVENTS: replaced_event = getattr(self, str(event)) - - def fun(*xargs, _js=None, replaced_event=replaced_event, **xkwargs): - if _js: - xkwargs['js'] = _js - - return replaced_event(*xargs, **xkwargs) - + fun = EventWrapper(replaced_event) setattr(self, str(event), fun) grclass.__init__ = __repaired_init__ diff --git a/modules/ui.py b/modules/ui.py index 618e9da9..b62f45ec 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -26,7 +26,7 @@ import modules.shared as shared from modules import prompt_parser from modules.infotext_utils import image_from_url_text, PasteField from modules_forge.forge_canvas.canvas import ForgeCanvas, canvas_head -from modules_forge import main_entry +from modules_forge import main_entry, forge_space create_setting_component = ui_settings.create_setting_component @@ -853,6 +853,9 @@ def create_ui(): extra_tabs.__exit__() + with gr.Blocks(analytics_enabled=False, head=canvas_head) as space_interface: + forge_space.main_entry() + scripts.scripts_current = None with gr.Blocks(analytics_enabled=False, head=canvas_head) as extras_interface: @@ -891,6 +894,7 @@ def create_ui(): interfaces = [ (txt2img_interface, "Txt2img", "txt2img"), (img2img_interface, "Img2img", "img2img"), + (space_interface, "Spaces", "space"), (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), (modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"), diff --git a/modules_forge/forge_space.py b/modules_forge/forge_space.py new file mode 100644 index 00000000..0782ec14 --- /dev/null +++ b/modules_forge/forge_space.py @@ -0,0 +1,159 @@ +import os +import sys +import uuid +import time +import gradio as gr +import importlib.util +import shutil + +from gradio.context import Context +from threading import Thread +from huggingface_hub import snapshot_download +from backend import memory_management + + +spaces = [] + + +def build_html(title, installed=False, url=None): + if not installed: + return f'
{title}
Not Installed
' + + if isinstance(url, str): + return f'
{title}
Currently Running: {url}
' + else: + return f'
{title}
Installed, Ready to Launch
' + + +class ForgeSpace: + def __init__(self, root_path, title, repo_id=None, repo_type='space', revision=None, **kwargs): + self.title = title + self.root_path = root_path + self.hf_path = os.path.join(root_path, 'huggingface_space_mirror') + self.repo_id = repo_id + self.repo_type = repo_type + self.revision = revision + self.is_running = False + self.gradio_metas = None + + self.label = gr.HTML(build_html(title=title, url=None), elem_classes=['forge_space_label']) + self.btn_launch = gr.Button('Launch', elem_classes=['forge_space_btn']) + self.btn_terminate = gr.Button('Terminate', elem_classes=['forge_space_btn']) + self.btn_install = gr.Button('Install', elem_classes=['forge_space_btn']) + self.btn_uninstall = gr.Button('Uninstall', elem_classes=['forge_space_btn']) + + comps = [ + self.label, + self.btn_install, + self.btn_uninstall, + self.btn_launch, + self.btn_terminate + ] + + self.btn_launch.click(self.run, outputs=comps) + self.btn_terminate.click(self.terminate, outputs=comps) + self.btn_install.click(self.install, outputs=comps) + self.btn_uninstall.click(self.uninstall, outputs=comps) + Context.root_block.load(self.refresh_gradio, outputs=comps, queue=False, show_progress=False) + + return + + def refresh_gradio(self): + results = [] + + installed = os.path.exists(self.hf_path) + + if isinstance(self.gradio_metas, tuple): + results.append(build_html(title=self.title, installed=installed, url=self.gradio_metas[1])) + else: + results.append(build_html(title=self.title, installed=installed, url=None)) + + results.append(gr.update(interactive=not installed)) + results.append(gr.update(interactive=installed)) + results.append(gr.update(interactive=installed and not self.is_running)) + results.append(gr.update(interactive=installed and self.is_running)) + return results + + def install(self): + os.makedirs(self.hf_path, exist_ok=True) + + if self.repo_id is None: + return self.refresh_gradio() + + downloaded = snapshot_download( + repo_id=self.repo_id, + repo_type=self.repo_type, + revision=self.revision, + local_dir=self.hf_path, + force_download=True, + ) + + print(f'Downloaded: {downloaded}') + return self.refresh_gradio() + + def uninstall(self): + shutil.rmtree(self.hf_path) + print(f'Deleted: {self.hf_path}') + return self.refresh_gradio() + + def terminate(self): + self.is_running = False + while self.gradio_metas is not None: + time.sleep(0.1) + return self.refresh_gradio() + + def run(self): + self.is_running = True + Thread(target=self.gradio_worker).start() + while self.gradio_metas is None: + time.sleep(0.1) + return self.refresh_gradio() + + def gradio_worker(self): + memory_management.unload_all_models() + sys.path.insert(0, self.hf_path) + file_path = os.path.join(self.root_path, 'forge_app.py') + module_name = 'forge_space_' + str(uuid.uuid4()).replace('-', '_') + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + demo = getattr(module, 'demo') + + self.gradio_metas = demo.launch(inbrowser=True, prevent_thread_lock=True) + + while self.is_running: + time.sleep(0.1) + + demo.close() + self.gradio_metas = None + + if module_name in sys.modules: + del sys.modules[module_name] + + return + + +def main_entry(): + global spaces + + from modules.extensions import extensions + + tagged_extensions = {} + + for ex in extensions: + if ex.enabled and ex.is_forge_space: + tag = ex.space_meta['tag'] + + if tag not in tagged_extensions: + tagged_extensions[tag] = [] + + tagged_extensions[tag].append(ex) + + for tag, exs in tagged_extensions.items(): + with gr.Accordion(tag, open=True): + for ex in exs: + with gr.Row(equal_height=True): + space = ForgeSpace(root_path=ex.path, **ex.space_meta) + spaces.append(space) + + return diff --git a/modules_forge/initialization.py b/modules_forge/initialization.py index 800909f8..0420dea4 100644 --- a/modules_forge/initialization.py +++ b/modules_forge/initialization.py @@ -2,6 +2,7 @@ import os import sys +INITIALIZED = False MONITOR_MODEL_MOVING = False @@ -25,6 +26,13 @@ def monitor_module_moving(): def initialize_forge(): + global INITIALIZED + + if INITIALIZED: + return + + INITIALIZED = True + sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'packages_3rdparty')) bad_list = ['--lowvram', '--medvram', '--medvram-sdxl'] @@ -60,9 +68,6 @@ def initialize_forge(): from modules_forge.bnb_installer import try_install_bnb try_install_bnb() - import modules_forge.patch_basic - modules_forge.patch_basic.patch_all_basics() - from backend import stream print('CUDA Using Stream:', stream.should_use_stream()) @@ -85,4 +90,8 @@ def initialize_forge(): if 'HF_HUB_CACHE' not in os.environ: os.environ['HF_HUB_CACHE'] = diffusers_dir + + import modules_forge.patch_basic + modules_forge.patch_basic.patch_all_basics() + return diff --git a/modules_forge/patch_basic.py b/modules_forge/patch_basic.py index 87048a20..d22bc003 100644 --- a/modules_forge/patch_basic.py +++ b/modules_forge/patch_basic.py @@ -6,6 +6,8 @@ import warnings import gradio.networking import safetensors.torch +from tqdm import tqdm + def gradio_url_ok_fix(url: str) -> bool: try: @@ -55,7 +57,20 @@ def build_loaded(module, loader_name): return +def always_show_tqdm(*args, **kwargs): + kwargs['disable'] = False + if 'name' in kwargs: + del kwargs['name'] + return tqdm(*args, **kwargs) + + def patch_all_basics(): + import logging + from huggingface_hub import file_download + file_download.tqdm = always_show_tqdm + from transformers.dynamic_module_utils import logger + logger.setLevel(logging.ERROR) + gradio.networking.url_ok = gradio_url_ok_fix build_loaded(safetensors.torch, 'load_file') build_loaded(torch, 'load') diff --git a/requirements_versions.txt b/requirements_versions.txt index 3c608adb..1d3f04cf 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -35,4 +35,6 @@ httpx==0.24.1 pillow-avif-plugin==1.4.3 diffusers==0.29.2 gradio_rangeslider==0.0.6 +gradio_imageslider==0.0.20 +loadimg==0.1.2 tqdm==4.66.1 diff --git a/spaces.py b/spaces.py new file mode 100644 index 00000000..bb9dfc6f --- /dev/null +++ b/spaces.py @@ -0,0 +1,79 @@ +from modules_forge.initialization import initialize_forge + +initialize_forge() + +import os +import torch +import inspect + +from backend import memory_management + + +gpu = memory_management.get_torch_device() + + +class GPUObject: + def __init__(self): + self.module_list = [] + + def __enter__(self): + self.original_init = torch.nn.Module.__init__ + self.original_to = torch.nn.Module.to + + def patched_init(module, *args, **kwargs): + self.module_list.append(module) + return self.original_init(module, *args, **kwargs) + + def patched_to(module, *args, **kwargs): + self.module_list.append(module) + return self.original_to(module, *args, **kwargs) + + torch.nn.Module.__init__ = patched_init + torch.nn.Module.to = patched_to + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.nn.Module.__init__ = self.original_init + torch.nn.Module.to = self.original_to + self.module_list = set(self.module_list) + self.to(device=torch.device('cpu')) + memory_management.soft_empty_cache() + return + + def to(self, device): + for module in self.module_list: + module.to(device) + print(f'Forge Space: Moved {len(self.module_list)} Modules to {device}') + return self + + def gpu(self): + self.to(device=gpu) + return self + + +def GPU(gpu_objects=None, manual_load=False): + gpu_objects = gpu_objects or [] + + def decorator(func): + def wrapper(*args, **kwargs): + print("Entering Forge Space GPU ...") + memory_management.unload_all_models() + if not manual_load: + for o in gpu_objects: + o.gpu() + result = func(*args, **kwargs) + print("Cleaning Forge Space GPU ...") + for o in gpu_objects: + o.to(device=torch.device('cpu')) + memory_management.soft_empty_cache() + return result + return wrapper + return decorator + + +def convert_root_path(): + frame = inspect.currentframe().f_back + caller_file = frame.f_code.co_filename + caller_file = os.path.abspath(caller_file) + result = os.path.join(os.path.dirname(caller_file), 'huggingface_space_mirror') + return result + '/' diff --git a/style.css b/style.css index 9cc2534f..29e7418d 100644 --- a/style.css +++ b/style.css @@ -1673,3 +1673,17 @@ body.resizing .resize-handle { #quicksettings .gradio-slider span { padding-right: 5px; } + +.forge_space_label{ + padding: 10px; + min-width: 60% !important; + margin: 1px; + border-width: 1px; + border-radius: 8px; + border-style: solid; + border-color: #6f6f6f; +} + +.forge_space_btn{ + min-width: 0 !important; +} diff --git a/webui.py b/webui.py index 921edce9..ab0e6c2f 100644 --- a/webui.py +++ b/webui.py @@ -11,8 +11,6 @@ from modules_forge.initialization import initialize_forge from modules_forge import main_thread -from modules_forge.forge_canvas.canvas import canvas_js_root_path - startup_timer = timer.startup_timer startup_timer.record("launcher") @@ -83,6 +81,8 @@ def webui_worker(): elif shared.opts.auto_launch_browser == "Local": auto_launch_browser = not cmd_opts.webui_is_non_local + from modules_forge.forge_canvas.canvas import canvas_js_root_path + app, local_url, share_url = shared.demo.launch( share=cmd_opts.share, server_name=initialize_util.gradio_server_name(),