mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-05-01 03:31:30 +00:00
rework sd1.5 and sdxl from scratch
This commit is contained in:
@@ -1,24 +1,10 @@
|
||||
import torch
|
||||
|
||||
from backend import memory_management
|
||||
from backend.patcher.base import ModelPatcher
|
||||
|
||||
|
||||
class JointTokenizer:
|
||||
def __init__(self, huggingface_components):
|
||||
self.clip_l = huggingface_components.get('tokenizer', None)
|
||||
self.clip_g = huggingface_components.get('tokenizer_2', None)
|
||||
|
||||
|
||||
class JointCLIPTextEncoder(torch.nn.Module):
|
||||
def __init__(self, huggingface_components):
|
||||
super().__init__()
|
||||
self.clip_l = huggingface_components.get('text_encoder', None)
|
||||
self.clip_g = huggingface_components.get('text_encoder_2', None)
|
||||
from backend.nn.base import ModuleDict, ObjectDict
|
||||
|
||||
|
||||
class CLIP:
|
||||
def __init__(self, huggingface_components=None, no_init=False):
|
||||
def __init__(self, model_dict={}, tokenizer_dict={}, no_init=False):
|
||||
if no_init:
|
||||
return
|
||||
|
||||
@@ -26,8 +12,8 @@ class CLIP:
|
||||
offload_device = memory_management.text_encoder_offload_device()
|
||||
text_encoder_dtype = memory_management.text_encoder_dtype(load_device)
|
||||
|
||||
self.cond_stage_model = JointCLIPTextEncoder(huggingface_components)
|
||||
self.tokenizer = JointTokenizer(huggingface_components)
|
||||
self.cond_stage_model = ModuleDict(model_dict)
|
||||
self.tokenizer = ObjectDict(tokenizer_dict)
|
||||
self.cond_stage_model.to(dtype=text_encoder_dtype, device=offload_device)
|
||||
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
|
||||
|
||||
@@ -1,12 +1,26 @@
|
||||
import copy
|
||||
import torch
|
||||
|
||||
from backend.modules.k_model import KModel
|
||||
from backend.patcher.base import ModelPatcher
|
||||
from backend import memory_management
|
||||
|
||||
|
||||
class UnetPatcher(ModelPatcher):
|
||||
def __init__(self, model, *args, **kwargs):
|
||||
super().__init__(model, *args, **kwargs)
|
||||
@classmethod
|
||||
def from_model(cls, model, diffusers_scheduler):
|
||||
parameters = memory_management.module_size(model)
|
||||
unet_dtype = memory_management.unet_dtype(model_params=parameters)
|
||||
load_device = memory_management.get_torch_device()
|
||||
initial_load_device = memory_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device)
|
||||
manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype
|
||||
model.to(device=initial_load_device, dtype=unet_dtype)
|
||||
model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype)
|
||||
return UnetPatcher(model, load_device=load_device, offload_device=memory_management.unet_offload_device(), current_device=initial_load_device)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.controlnet_linked_list = None
|
||||
self.extra_preserved_memory_during_sampling = 0
|
||||
self.extra_model_patchers_during_sampling = []
|
||||
|
||||
Reference in New Issue
Block a user