commit 1e64709d41
parent 608af2e64c
Author: layerdiffusion
Date: 2024-08-18 03:27:24 -07:00


@@ -6,13 +6,30 @@ import os
 import torch
 import inspect
 import functools
+import gradio.oauth
+import gradio.routes
 
 from backend import memory_management
 from diffusers.models import modeling_utils as diffusers_modeling_utils
 from transformers import modeling_utils as transformers_modeling_utils
 from backend.attention import AttentionProcessorForge
+from starlette.requests import Request
+
+_original_init = Request.__init__
+
+
+def patched_init(self, scope, receive=None, send=None):
+    if 'session' not in scope:
+        scope['session'] = dict()
+    _original_init(self, scope, receive, send)
+    return
+
+
+Request.__init__ = patched_init
+gradio.oauth.attach_oauth = lambda x: None
+gradio.routes.attach_oauth = lambda x: None
 
 
 module_in_gpu: torch.nn.Module = None
 gpu = memory_management.get_torch_device()
 cpu = torch.device('cpu')
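
The added lines above stub out gradio's attach_oauth hooks and wrap starlette's Request.__init__ so that request.session no longer raises when SessionMiddleware is absent. Below is a minimal, self-contained sketch of that same wrapping pattern, assuming only that starlette is installed; the bare scope dict at the end is illustrative, not taken from the commit.

from starlette.requests import Request

_original_init = Request.__init__


def patched_init(self, scope, receive=None, send=None):
    # Ensure a 'session' entry exists before the real constructor runs,
    # so request.session works even without SessionMiddleware.
    if 'session' not in scope:
        scope['session'] = dict()
    _original_init(self, scope, receive, send)


Request.__init__ = patched_init

# A bare HTTP scope with no 'session' key no longer trips the
# "SessionMiddleware must be installed" assertion.
request = Request({'type': 'http', 'method': 'GET', 'path': '/', 'headers': []})
print(request.session)  # -> {}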
@@ -116,7 +133,10 @@ def convert_root_path():
     return result + '/'
 
 
-def automatically_move_to_gpu_when_forward(m: torch.nn.Module):
+def automatically_move_to_gpu_when_forward(m: torch.nn.Module, target_model: torch.nn.Module = None):
+    if target_model is None:
+        target_model = m
+
     def patch_method(method_name):
         if not hasattr(m, method_name):
             return
@@ -132,7 +152,7 @@ def automatically_move_to_gpu_when_forward(m: torch.nn.Module):
         original_method = getattr(m, method_name)
 
         def patched_method(*args, **kwargs):
-            load_module(m)
+            load_module(target_model)
             return original_method(*args, **kwargs)
 
         setattr(m, method_name, patched_method)
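
The new target_model argument lets a caller patch a submodule's methods while still loading a different (typically the enclosing, larger) model before the call runs. A rough, self-contained sketch of that indirection follows; it patches only forward (the real helper patches several methods via patch_method), and load_module here is a toy stand-in for Forge's loader, not its actual implementation.

import torch


def load_module(m: torch.nn.Module) -> None:
    # Stand-in for the real loader, which would move weights onto the GPU.
    print(f'load requested for {type(m).__name__}')


def automatically_move_to_gpu_when_forward(m: torch.nn.Module, target_model: torch.nn.Module = None):
    if target_model is None:
        target_model = m

    original_forward = m.forward

    def patched_forward(*args, **kwargs):
        # Load the *target* model, not just the patched submodule, so a hook
        # on one layer can pull its whole parent network into memory first.
        load_module(target_model)
        return original_forward(*args, **kwargs)

    m.forward = patched_forward


parent = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
automatically_move_to_gpu_when_forward(parent[0], target_model=parent)
parent(torch.zeros(1, 4))  # prints: load requested for Sequential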