Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git, synced 2026-02-03 14:54:23 +00:00
revise space logics
@@ -38,3 +38,4 @@ gradio_rangeslider==0.0.6
 gradio_imageslider==0.0.20
+loadimg==0.1.2
 tqdm==4.66.1
 peft==0.12.0
spaces.py (84 changed lines)
@@ -7,10 +7,43 @@ import torch
 import inspect
 
 from backend import memory_management
+from diffusers.models import modeling_utils as diffusers_modeling_utils
+from transformers import modeling_utils as transformers_modeling_utils
 from backend.attention import AttentionProcessorForge
 
 
+module_in_gpu: torch.nn.Module = None
 gpu = memory_management.get_torch_device()
 cpu = torch.device('cpu')
 
+diffusers_modeling_utils.get_parameter_device = lambda *args, **kwargs: gpu
+transformers_modeling_utils.get_parameter_device = lambda *args, **kwargs: gpu
+
+
+def unload_module():
+    global module_in_gpu
+
+    if module_in_gpu is None:
+        return
+
+    print(f'Moved module to CPU: {type(module_in_gpu).__name__}')
+    module_in_gpu.to(cpu)
+    module_in_gpu = None
+    memory_management.soft_empty_cache()
+    return
+
+
+def load_module(m):
+    global module_in_gpu
+
+    if module_in_gpu == m:
+        return
+
+    unload_module()
+    module_in_gpu = m
+    module_in_gpu.to(gpu)
+    print(f'Moved module to GPU: {type(module_in_gpu).__name__}')
+    return
+
+
 class GPUObject:
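For context, a minimal sketch (editor's illustration, not part of the commit) of the contract the new load_module/unload_module helpers establish: at most one module owns the GPU at a time, and loading a new one evicts the previous one to CPU. The get_parameter_device monkeypatches presumably make diffusers/transformers keep reporting the GPU as the parameter device even while weights sit on CPU between swaps. The `import spaces` path and the toy Linear modules below are assumptions:

import torch
import spaces  # import path assumed; this is the file the commit edits

# Hypothetical toy modules standing in for real pipeline components.
a = torch.nn.Linear(4, 4)
b = torch.nn.Linear(4, 4)

spaces.load_module(a)   # moves a onto the GPU: 'Moved module to GPU: Linear'
spaces.load_module(a)   # no-op; a already owns the single GPU slot
spaces.load_module(b)   # evicts a back to CPU first, then loads b
spaces.unload_module()  # empties the slot and calls soft_empty_cache()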
@@ -57,7 +90,6 @@ def GPU(gpu_objects=None, manual_load=False):
 
     def decorator(func):
         def wrapper(*args, **kwargs):
-            global module_in_gpu
            print("Entering Forge Space GPU ...")
             memory_management.unload_all_models()
             if not manual_load:
@@ -65,9 +97,7 @@ def GPU(gpu_objects=None, manual_load=False):
                     o.gpu()
             result = func(*args, **kwargs)
             print("Cleaning Forge Space GPU ...")
-            if module_in_gpu is not None:
-                module_in_gpu.to(device=torch.device('cpu'))
-                module_in_gpu = None
+            unload_module()
             for o in gpu_objects:
                 o.to(device=torch.device('cpu'))
             memory_management.soft_empty_cache()
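An illustrative usage sketch (not part of the commit; assumes the Forge backend is importable and that this file is imported as `spaces`): after these two hunks the wrapper no longer touches module_in_gpu directly and instead funnels all post-run offload through unload_module(). The function name below is hypothetical:

import spaces  # import path assumed

@spaces.GPU(gpu_objects=[], manual_load=False)
def run_space():
    # ... build models and run inference here; any module hooked by the
    # helpers below is pulled onto the GPU on first use ...
    return 'done'

run_space()  # prints 'Entering Forge Space GPU ...' then 'Cleaning Forge Space GPU ...'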
@@ -85,17 +115,43 @@ def convert_root_path():
 
 
 def automatically_move_to_gpu_when_forward(m: torch.nn.Module):
     assert isinstance(m, torch.nn.Module), 'Cannot manage models other than torch Module!'
+    def patch_method(method_name):
+        if not hasattr(m, method_name):
+            return
 
-    def send_me_to_gpu(*args, **kwargs):
-        global module_in_gpu
-        if module_in_gpu is not None:
-            module_in_gpu.to(device=torch.device('cpu'))
-            module_in_gpu = None
-            memory_management.soft_empty_cache()
-        module_in_gpu = m
-        module_in_gpu.to(gpu)
+        if not hasattr(m, 'forge_space_hooked_names'):
+            m.forge_space_hooked_names = []
+
+        if method_name in m.forge_space_hooked_names:
+            return
+
+        print(f'Automatic hook: {type(m).__name__}.{method_name}')
+
+        original_method = getattr(m, method_name)
+
+        def patched_method(*args, **kwargs):
+            load_module(m)
+            return original_method(*args, **kwargs)
+
+        setattr(m, method_name, patched_method)
+
+        m.forge_space_hooked_names.append(method_name)
+        return
 
-    m.register_forward_pre_hook(send_me_to_gpu)
+    for method_name in ['forward', 'encode', 'decode']:
+        patch_method(method_name)
 
     return
 
 
+def automatically_move_pipeline_components(pipe):
+    for attr_name in dir(pipe):
+        attr_value = getattr(pipe, attr_name, None)
+        if isinstance(attr_value, torch.nn.Module):
+            automatically_move_to_gpu_when_forward(attr_value)
+    return
+
+
+def change_attention_from_diffusers_to_forge(m):
+    m.set_attn_processor(AttentionProcessorForge())
+    return
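Finally, a hedged sketch of how the two newly added helpers are meant to be wired to a diffusers pipeline (editor's illustration, not from the commit; the checkpoint id and variable names are assumptions):

import spaces  # import path assumed
from diffusers import StableDiffusionPipeline

# Any diffusers pipeline should work the same way; this checkpoint is only an example.
pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5')

# Hook every nn.Module attribute (unet, vae, text_encoder, ...) so that the first
# call to a hooked forward/encode/decode loads that component onto the GPU and
# evicts whichever module held the slot before.
spaces.automatically_move_pipeline_components(pipe)

# Swap the diffusers attention processors for Forge's implementation.
spaces.change_attention_from_diffusers_to_forge(pipe.unet)

image = pipe('a photo of a cat').images[0]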