revise space logics

This commit is contained in:
layerdiffusion
2024-08-18 00:04:02 -07:00
parent 53cd00d125
commit 0ccbac5389
2 changed files with 26 additions and 2 deletions

View File

@@ -7,7 +7,7 @@ from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
torch.set_float32_matmul_precision(["high", "highest"][0])
# torch.set_float32_matmul_precision(["high", "highest"][0])
os.environ['HOME'] = spaces.convert_root_path() + 'home'
@@ -16,6 +16,8 @@ with spaces.GPUObject() as birefnet_gpu_obj:
"ZhengPeng7/BiRefNet", trust_remote_code=True
)
spaces.automatically_move_to_gpu_when_forward(birefnet)
transform_image = transforms.Compose(
[
transforms.Resize((1024, 1024)),
@@ -25,7 +27,7 @@ transform_image = transforms.Compose(
)
@spaces.GPU(gpu_objects=[birefnet_gpu_obj])
@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True)
def fn(image):
im = load_img(image, output_type="pil")
im = im.convert("RGB")

View File

@@ -9,6 +9,7 @@ import inspect
from backend import memory_management
module_in_gpu: torch.nn.Module = None
gpu = memory_management.get_torch_device()
@@ -56,6 +57,7 @@ def GPU(gpu_objects=None, manual_load=False):
def decorator(func):
def wrapper(*args, **kwargs):
global module_in_gpu
print("Entering Forge Space GPU ...")
memory_management.unload_all_models()
if not manual_load:
@@ -63,6 +65,9 @@ def GPU(gpu_objects=None, manual_load=False):
o.gpu()
result = func(*args, **kwargs)
print("Cleaning Forge Space GPU ...")
if module_in_gpu is not None:
module_in_gpu.to(device=torch.device('cpu'))
module_in_gpu = None
for o in gpu_objects:
o.to(device=torch.device('cpu'))
memory_management.soft_empty_cache()
@@ -77,3 +82,20 @@ def convert_root_path():
caller_file = os.path.abspath(caller_file)
result = os.path.join(os.path.dirname(caller_file), 'huggingface_space_mirror')
return result + '/'
def automatically_move_to_gpu_when_forward(m: torch.nn.Module):
assert isinstance(m, torch.nn.Module), 'Cannot manage models other than torch Module!'
def send_me_to_gpu(*args, **kwargs):
global module_in_gpu
if module_in_gpu is not None:
module_in_gpu.to(device=torch.device('cpu'))
module_in_gpu = None
memory_management.soft_empty_cache()
module_in_gpu = m
module_in_gpu.to(gpu)
return
m.register_forward_pre_hook(send_me_to_gpu)
return