mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-03 12:09:51 +00:00
revise space logics
This commit is contained in:
@@ -7,7 +7,7 @@ from transformers import AutoModelForImageSegmentation
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
|
||||
torch.set_float32_matmul_precision(["high", "highest"][0])
|
||||
# torch.set_float32_matmul_precision(["high", "highest"][0])
|
||||
|
||||
os.environ['HOME'] = spaces.convert_root_path() + 'home'
|
||||
|
||||
@@ -16,6 +16,8 @@ with spaces.GPUObject() as birefnet_gpu_obj:
|
||||
"ZhengPeng7/BiRefNet", trust_remote_code=True
|
||||
)
|
||||
|
||||
spaces.automatically_move_to_gpu_when_forward(birefnet)
|
||||
|
||||
transform_image = transforms.Compose(
|
||||
[
|
||||
transforms.Resize((1024, 1024)),
|
||||
@@ -25,7 +27,7 @@ transform_image = transforms.Compose(
|
||||
)
|
||||
|
||||
|
||||
@spaces.GPU(gpu_objects=[birefnet_gpu_obj])
|
||||
@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True)
|
||||
def fn(image):
|
||||
im = load_img(image, output_type="pil")
|
||||
im = im.convert("RGB")
|
||||
|
||||
22
spaces.py
22
spaces.py
@@ -9,6 +9,7 @@ import inspect
|
||||
from backend import memory_management
|
||||
|
||||
|
||||
module_in_gpu: torch.nn.Module = None
|
||||
gpu = memory_management.get_torch_device()
|
||||
|
||||
|
||||
@@ -56,6 +57,7 @@ def GPU(gpu_objects=None, manual_load=False):
|
||||
|
||||
def decorator(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
global module_in_gpu
|
||||
print("Entering Forge Space GPU ...")
|
||||
memory_management.unload_all_models()
|
||||
if not manual_load:
|
||||
@@ -63,6 +65,9 @@ def GPU(gpu_objects=None, manual_load=False):
|
||||
o.gpu()
|
||||
result = func(*args, **kwargs)
|
||||
print("Cleaning Forge Space GPU ...")
|
||||
if module_in_gpu is not None:
|
||||
module_in_gpu.to(device=torch.device('cpu'))
|
||||
module_in_gpu = None
|
||||
for o in gpu_objects:
|
||||
o.to(device=torch.device('cpu'))
|
||||
memory_management.soft_empty_cache()
|
||||
@@ -77,3 +82,20 @@ def convert_root_path():
|
||||
caller_file = os.path.abspath(caller_file)
|
||||
result = os.path.join(os.path.dirname(caller_file), 'huggingface_space_mirror')
|
||||
return result + '/'
|
||||
|
||||
|
||||
def automatically_move_to_gpu_when_forward(m: torch.nn.Module):
|
||||
assert isinstance(m, torch.nn.Module), 'Cannot manage models other than torch Module!'
|
||||
|
||||
def send_me_to_gpu(*args, **kwargs):
|
||||
global module_in_gpu
|
||||
if module_in_gpu is not None:
|
||||
module_in_gpu.to(device=torch.device('cpu'))
|
||||
module_in_gpu = None
|
||||
memory_management.soft_empty_cache()
|
||||
module_in_gpu = m
|
||||
module_in_gpu.to(gpu)
|
||||
return
|
||||
|
||||
m.register_forward_pre_hook(send_me_to_gpu)
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user