revise space logics

This commit is contained in:
layerdiffusion
2024-08-18 01:44:29 -07:00
parent 72ab92f83e
commit 101b556ee5
2 changed files with 71 additions and 14 deletions

View File

@@ -38,3 +38,4 @@ gradio_rangeslider==0.0.6
gradio_imageslider==0.0.20
loadimg==0.1.2
tqdm==4.66.1
peft==0.12.0

View File

@@ -7,10 +7,43 @@ import torch
import inspect
from backend import memory_management
from diffusers.models import modeling_utils as diffusers_modeling_utils
from transformers import modeling_utils as transformers_modeling_utils
from backend.attention import AttentionProcessorForge
module_in_gpu: torch.nn.Module = None
gpu = memory_management.get_torch_device()
cpu = torch.device('cpu')
diffusers_modeling_utils.get_parameter_device = lambda *args, **kwargs: gpu
transformers_modeling_utils.get_parameter_device = lambda *args, **kwargs: gpu
def unload_module():
global module_in_gpu
if module_in_gpu is None:
return
print(f'Moved module to CPU: {type(module_in_gpu).__name__}')
module_in_gpu.to(cpu)
module_in_gpu = None
memory_management.soft_empty_cache()
return
def load_module(m):
global module_in_gpu
if module_in_gpu == m:
return
unload_module()
module_in_gpu = m
module_in_gpu.to(gpu)
print(f'Moved module to GPU: {type(module_in_gpu).__name__}')
return
class GPUObject:
@@ -57,7 +90,6 @@ def GPU(gpu_objects=None, manual_load=False):
def decorator(func):
def wrapper(*args, **kwargs):
global module_in_gpu
print("Entering Forge Space GPU ...")
memory_management.unload_all_models()
if not manual_load:
@@ -65,9 +97,7 @@ def GPU(gpu_objects=None, manual_load=False):
o.gpu()
result = func(*args, **kwargs)
print("Cleaning Forge Space GPU ...")
if module_in_gpu is not None:
module_in_gpu.to(device=torch.device('cpu'))
module_in_gpu = None
unload_module()
for o in gpu_objects:
o.to(device=torch.device('cpu'))
memory_management.soft_empty_cache()
@@ -85,17 +115,43 @@ def convert_root_path():
def automatically_move_to_gpu_when_forward(m: torch.nn.Module):
assert isinstance(m, torch.nn.Module), 'Cannot manage models other than torch Module!'
def patch_method(method_name):
if not hasattr(m, method_name):
return
def send_me_to_gpu(*args, **kwargs):
global module_in_gpu
if module_in_gpu is not None:
module_in_gpu.to(device=torch.device('cpu'))
module_in_gpu = None
memory_management.soft_empty_cache()
module_in_gpu = m
module_in_gpu.to(gpu)
if not hasattr(m, 'forge_space_hooked_names'):
m.forge_space_hooked_names = []
if method_name in m.forge_space_hooked_names:
return
print(f'Automatic hook: {type(m).__name__}.{method_name}')
original_method = getattr(m, method_name)
def patched_method(*args, **kwargs):
load_module(m)
return original_method(*args, **kwargs)
setattr(m, method_name, patched_method)
m.forge_space_hooked_names.append(method_name)
return
m.register_forward_pre_hook(send_me_to_gpu)
for method_name in ['forward', 'encode', 'decode']:
patch_method(method_name)
return
def automatically_move_pipeline_components(pipe):
for attr_name in dir(pipe):
attr_value = getattr(pipe, attr_name, None)
if isinstance(attr_value, torch.nn.Module):
automatically_move_to_gpu_when_forward(attr_value)
return
def change_attention_from_diffusers_to_forge(m):
m.set_attn_processor(AttentionProcessorForge())
return