diff --git a/extensions-builtin/forge_space_birefnet/forge_app.py b/extensions-builtin/forge_space_birefnet/forge_app.py
index 6199819e..f2d9875c 100644
--- a/extensions-builtin/forge_space_birefnet/forge_app.py
+++ b/extensions-builtin/forge_space_birefnet/forge_app.py
@@ -7,7 +7,7 @@ from transformers import AutoModelForImageSegmentation
 import torch
 from torchvision import transforms
 
-torch.set_float32_matmul_precision(["high", "highest"][0])
+# torch.set_float32_matmul_precision(["high", "highest"][0])
 
 os.environ['HOME'] = spaces.convert_root_path() + 'home'
 
@@ -16,6 +16,8 @@ with spaces.GPUObject() as birefnet_gpu_obj:
         "ZhengPeng7/BiRefNet", trust_remote_code=True
     )
 
+spaces.automatically_move_to_gpu_when_forward(birefnet)
+
 transform_image = transforms.Compose(
     [
         transforms.Resize((1024, 1024)),
@@ -25,7 +27,7 @@ transform_image = transforms.Compose(
 )
 
 
-@spaces.GPU(gpu_objects=[birefnet_gpu_obj])
+@spaces.GPU(gpu_objects=[birefnet_gpu_obj], manual_load=True)
 def fn(image):
     im = load_img(image, output_type="pil")
     im = im.convert("RGB")
diff --git a/spaces.py b/spaces.py
index bb9dfc6f..a0a05c70 100644
--- a/spaces.py
+++ b/spaces.py
@@ -9,6 +9,7 @@ import inspect
 
 from backend import memory_management
 
+module_in_gpu: torch.nn.Module = None
 
 gpu = memory_management.get_torch_device()
 
@@ -56,6 +57,7 @@ def GPU(gpu_objects=None, manual_load=False):
 
     def decorator(func):
         def wrapper(*args, **kwargs):
+            global module_in_gpu
             print("Entering Forge Space GPU ...")
             memory_management.unload_all_models()
             if not manual_load:
@@ -63,6 +65,9 @@
                     o.gpu()
             result = func(*args, **kwargs)
             print("Cleaning Forge Space GPU ...")
+            if module_in_gpu is not None:
+                module_in_gpu.to(device=torch.device('cpu'))
+                module_in_gpu = None
             for o in gpu_objects:
                 o.to(device=torch.device('cpu'))
             memory_management.soft_empty_cache()
@@ -77,3 +82,20 @@ def convert_root_path():
     caller_file = os.path.abspath(caller_file)
     result = os.path.join(os.path.dirname(caller_file), 'huggingface_space_mirror')
     return result + '/'
+
+
+def automatically_move_to_gpu_when_forward(m: torch.nn.Module):
+    assert isinstance(m, torch.nn.Module), 'Cannot manage models other than torch Module!'
+
+    def send_me_to_gpu(*args, **kwargs):
+        global module_in_gpu
+        if module_in_gpu is not None:
+            module_in_gpu.to(device=torch.device('cpu'))
+            module_in_gpu = None
+            memory_management.soft_empty_cache()
+        module_in_gpu = m
+        module_in_gpu.to(gpu)
+        return
+
+    m.register_forward_pre_hook(send_me_to_gpu)
+    return
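
Note on the new spaces.py helper: automatically_move_to_gpu_when_forward registers a forward pre-hook so a managed module is moved onto the GPU only when it is actually called, evicting whichever module held the GPU before. Below is a minimal standalone sketch of that pattern, assuming a stock PyTorch environment; the names move_to_gpu_on_forward and _current_gpu_module are illustrative (not part of the patch), and torch.cuda.empty_cache() stands in for Forge's memory_management.soft_empty_cache().

import torch

# Illustrative module-level tracker, mirroring the patch's module_in_gpu global.
_current_gpu_module = None


def move_to_gpu_on_forward(m: torch.nn.Module, device: torch.device):
    # Register a forward pre-hook: before m.forward() runs, evict whichever
    # module currently occupies the device, then move m onto it.
    assert isinstance(m, torch.nn.Module)

    def pre_hook(module, args):
        global _current_gpu_module
        if _current_gpu_module is not None and _current_gpu_module is not module:
            _current_gpu_module.to(torch.device('cpu'))
            torch.cuda.empty_cache()  # stand-in for memory_management.soft_empty_cache()
        _current_gpu_module = module
        module.to(device)

    m.register_forward_pre_hook(pre_hook)


# Usage: two models share one device; only the one being called resides on it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
a = torch.nn.Linear(8, 8)
b = torch.nn.Linear(8, 8)
move_to_gpu_on_forward(a, device)
move_to_gpu_on_forward(b, device)
x = torch.randn(1, 8, device=device)
a(x)  # pre-hook moves a onto the device
b(x)  # pre-hook evicts a to cpu, then moves b onto the device

Keeping a single tracker means at most one managed model occupies the GPU at a time; the decorator's new manual_load=True path relies on this, deferring the move to first forward() instead of loading every GPUObject up front.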