mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-01-26 19:09:45 +00:00
gradio
This commit is contained in:
24
spaces.py
24
spaces.py
@@ -6,13 +6,30 @@ import os
|
||||
import torch
|
||||
import inspect
|
||||
import functools
|
||||
import gradio.oauth
|
||||
import gradio.routes
|
||||
|
||||
from backend import memory_management
|
||||
from diffusers.models import modeling_utils as diffusers_modeling_utils
|
||||
from transformers import modeling_utils as transformers_modeling_utils
|
||||
from backend.attention import AttentionProcessorForge
|
||||
from starlette.requests import Request
|
||||
|
||||
|
||||
_original_init = Request.__init__
|
||||
|
||||
|
||||
def patched_init(self, scope, receive=None, send=None):
|
||||
if 'session' not in scope:
|
||||
scope['session'] = dict()
|
||||
_original_init(self, scope, receive, send)
|
||||
return
|
||||
|
||||
|
||||
Request.__init__ = patched_init
|
||||
gradio.oauth.attach_oauth = lambda x: None
|
||||
gradio.routes.attach_oauth = lambda x: None
|
||||
|
||||
module_in_gpu: torch.nn.Module = None
|
||||
gpu = memory_management.get_torch_device()
|
||||
cpu = torch.device('cpu')
|
||||
@@ -116,7 +133,10 @@ def convert_root_path():
|
||||
return result + '/'
|
||||
|
||||
|
||||
def automatically_move_to_gpu_when_forward(m: torch.nn.Module):
|
||||
def automatically_move_to_gpu_when_forward(m: torch.nn.Module, target_model: torch.nn.Module = None):
|
||||
if target_model is None:
|
||||
target_model = m
|
||||
|
||||
def patch_method(method_name):
|
||||
if not hasattr(m, method_name):
|
||||
return
|
||||
@@ -132,7 +152,7 @@ def automatically_move_to_gpu_when_forward(m: torch.nn.Module):
|
||||
original_method = getattr(m, method_name)
|
||||
|
||||
def patched_method(*args, **kwargs):
|
||||
load_module(m)
|
||||
load_module(target_model)
|
||||
return original_method(*args, **kwargs)
|
||||
|
||||
setattr(m, method_name, patched_method)
|
||||
|
||||
Reference in New Issue
Block a user