mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 11:11:15 +00:00
Update diffusers_patcher.py
This commit is contained in:
@@ -1,13 +1,14 @@
|
|||||||
import diffusers
|
|
||||||
import torch
|
import torch
|
||||||
import ldm_patched.modules.ops as ops
|
import ldm_patched.modules.ops as ops
|
||||||
|
|
||||||
|
from diffusers.models.attention_processor import AttnProcessor2_0
|
||||||
from ldm_patched.modules.model_patcher import ModelPatcher
|
from ldm_patched.modules.model_patcher import ModelPatcher
|
||||||
from ldm_patched.modules import model_management
|
from ldm_patched.modules import model_management
|
||||||
from modules_forge.ops import use_patched_ops
|
from modules_forge.ops import use_patched_ops
|
||||||
from transformers import modeling_utils
|
from transformers import modeling_utils
|
||||||
|
|
||||||
|
|
||||||
class DiffusersPatcher:
|
class DiffusersModelPatcher:
|
||||||
def __init__(self, pipeline_class, dtype=torch.float16, *args, **kwargs):
|
def __init__(self, pipeline_class, dtype=torch.float16, *args, **kwargs):
|
||||||
load_device = model_management.get_torch_device()
|
load_device = model_management.get_torch_device()
|
||||||
offload_device = torch.device("cpu")
|
offload_device = torch.device("cpu")
|
||||||
@@ -21,6 +22,10 @@ class DiffusersPatcher:
|
|||||||
with modeling_utils.no_init_weights():
|
with modeling_utils.no_init_weights():
|
||||||
self.pipeline = pipeline_class.from_pretrained(*args, **kwargs)
|
self.pipeline = pipeline_class.from_pretrained(*args, **kwargs)
|
||||||
|
|
||||||
|
if hasattr(self.pipeline, 'unet'):
|
||||||
|
self.pipeline.unet.set_attn_processor(AttnProcessor2_0())
|
||||||
|
print('Attention optimization applied to DiffusersModelPatcher')
|
||||||
|
|
||||||
self.pipeline = self.pipeline.to(device=offload_device, dtype=dtype)
|
self.pipeline = self.pipeline.to(device=offload_device, dtype=dtype)
|
||||||
self.pipeline.eval()
|
self.pipeline.eval()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user