Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2026-04-29 02:31:16 +00:00)

Commit: gradio

spaces.py (24 changed lines)
@@ -6,13 +6,30 @@ import os
 import torch
 import inspect
 import functools
+import gradio.oauth
+import gradio.routes
 
 from backend import memory_management
 from diffusers.models import modeling_utils as diffusers_modeling_utils
 from transformers import modeling_utils as transformers_modeling_utils
 from backend.attention import AttentionProcessorForge
+from starlette.requests import Request
+
+
+_original_init = Request.__init__
+
+
+def patched_init(self, scope, receive=None, send=None):
+    if 'session' not in scope:
+        scope['session'] = dict()
+    _original_init(self, scope, receive, send)
+    return
+
+Request.__init__ = patched_init
+gradio.oauth.attach_oauth = lambda x: None
+gradio.routes.attach_oauth = lambda x: None
 
 
 module_in_gpu: torch.nn.Module = None
 gpu = memory_management.get_torch_device()
 cpu = torch.device('cpu')
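
Note on this hunk: a minimal sketch (illustrative, not part of the commit) of why the Request patch is needed. Starlette's Request.session property asserts that SessionMiddleware has put a 'session' key into the ASGI scope; gradio's OAuth helpers touch request.session, so pre-seeding an empty dict (and stubbing out attach_oauth) avoids that assertion when no session middleware is installed.

# Illustrative sketch, assuming only that starlette is installed.
from starlette.requests import Request

scope = {'type': 'http', 'method': 'GET', 'path': '/', 'headers': []}

try:
    Request(scope).session
except AssertionError as e:
    print('without patch:', e)  # "SessionMiddleware must be installed ..."

scope['session'] = dict()       # this is what patched_init injects
print('with patch:', Request(scope).session)  # -> {}
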
@@ -116,7 +133,10 @@ def convert_root_path():
     return result + '/'
 
 
-def automatically_move_to_gpu_when_forward(m: torch.nn.Module):
+def automatically_move_to_gpu_when_forward(m: torch.nn.Module, target_model: torch.nn.Module = None):
+    if target_model is None:
+        target_model = m
+
     def patch_method(method_name):
         if not hasattr(m, method_name):
             return
@@ -132,7 +152,7 @@ def automatically_move_to_gpu_when_forward(m: torch.nn.Module):
         original_method = getattr(m, method_name)
 
         def patched_method(*args, **kwargs):
-            load_module(m)
+            load_module(target_model)
             return original_method(*args, **kwargs)
 
         setattr(m, method_name, patched_method)
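
Note on these two hunks: the new target_model parameter lets callers patch a submodule's methods while loading a larger parent model on first call. A hypothetical usage sketch follows; TinyModel and the import path are illustrative assumptions, and load_module plus the actual GPU move live inside spaces.py.

# Hypothetical usage sketch: names below are illustrative, not from the commit.
import torch
from spaces import automatically_move_to_gpu_when_forward  # assumed import path

class TinyModel(torch.nn.Module):          # stand-in for a real pipeline
    def __init__(self):
        super().__init__()
        self.text_encoder = torch.nn.Linear(8, 8)
        self.unet = torch.nn.Linear(8, 8)

full_model = TinyModel()

# Before this commit, load_module(m) moved only text_encoder on its first
# forward(); passing target_model now loads full_model as one unit, so the
# unet's weights arrive on the GPU in the same load.
automatically_move_to_gpu_when_forward(full_model.text_encoder,
                                       target_model=full_model)
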