mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-03 12:09:51 +00:00
UNet from Scratch
The backend rewrite is now about 50% finished. The estimated finish is in 72 hours. After that, many newer features will land.
This commit is contained in:
@@ -8,6 +8,7 @@ from ldm_patched.modules.sd import VAE, CLIP, load_model_weights
|
||||
import ldm_patched.modules.model_patcher
|
||||
import ldm_patched.modules.utils
|
||||
import ldm_patched.modules.clip_vision
|
||||
import backend.nn.unet
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from modules.sd_models_config import find_checkpoint_config
|
||||
@@ -19,6 +20,7 @@ from modules_forge import forge_clip
|
||||
from modules_forge.unet_patcher import UnetPatcher
|
||||
from ldm_patched.modules.model_base import model_sampling, ModelType
|
||||
from backend.loader import load_huggingface_components
|
||||
from backend.modules.k_model import KModel
|
||||
|
||||
import open_clip
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
@@ -85,27 +87,20 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||
load_device = model_management.get_torch_device()
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype
|
||||
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
|
||||
model_config.set_manual_cast(manual_cast_dtype)
|
||||
|
||||
if model_config is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type")
|
||||
|
||||
if model_config.clip_vision_prefix is not None:
|
||||
if output_clipvision:
|
||||
clipvision = ldm_patched.modules.clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
|
||||
initial_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
backend.nn.unet.unet_initial_device = initial_load_device
|
||||
backend.nn.unet.unet_initial_dtype = unet_dtype
|
||||
|
||||
huggingface_components = load_huggingface_components(sd)
|
||||
|
||||
if output_model:
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
offload_device = model_management.unet_offload_device()
|
||||
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
|
||||
model.load_model_weights(sd, "model.diffusion_model.")
|
||||
k_model = KModel(huggingface_components, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype)
|
||||
k_model.to(device=initial_load_device, dtype=unet_dtype)
|
||||
model_patcher = UnetPatcher(k_model, load_device=load_device,
|
||||
offload_device=model_management.unet_offload_device(),
|
||||
current_device=initial_load_device)
|
||||
|
||||
if output_vae:
|
||||
vae = huggingface_components['vae']
|
||||
@@ -118,12 +113,6 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
if len(left_over) > 0:
|
||||
print("left over keys:", left_over)
|
||||
|
||||
if output_model:
|
||||
model_patcher = UnetPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
||||
if inital_load_device != torch.device("cpu"):
|
||||
print("loaded straight to GPU")
|
||||
model_management.load_model_gpu(model_patcher)
|
||||
|
||||
return ForgeSD(model_patcher, clip, vae, clipvision)
|
||||
|
||||
|
||||
@@ -161,7 +150,7 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
|
||||
timer.record("forge load real models")
|
||||
|
||||
sd_model.first_stage_model = forge_objects.vae.first_stage_model
|
||||
sd_model.model.diffusion_model = forge_objects.unet.model.diffusion_model
|
||||
sd_model.model.diffusion_model = forge_objects.unet.model
|
||||
|
||||
conditioner = getattr(sd_model, 'conditioner', None)
|
||||
if conditioner:
|
||||
@@ -202,8 +191,8 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
|
||||
model_embeddings.token_embedding, sd_hijack.model_hijack)
|
||||
sd_model.cond_stage_model = forge_clip.CLIP_SD_15_L(sd_model.cond_stage_model, sd_hijack.model_hijack)
|
||||
elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder': # SD21 Clip
|
||||
sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_h
|
||||
sd_model.cond_stage_model.transformer = forge_objects.clip.cond_stage_model.clip_h.transformer
|
||||
sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_l
|
||||
sd_model.cond_stage_model.transformer = forge_objects.clip.cond_stage_model.clip_l.transformer
|
||||
model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
|
||||
model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
|
||||
model_embeddings.token_embedding, sd_hijack.model_hijack)
|
||||
@@ -216,9 +205,6 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
|
||||
if getattr(sd_model, 'parameterization', None) == 'v':
|
||||
sd_model.forge_objects.unet.model.model_sampling = model_sampling(sd_model.forge_objects.unet.model.model_config, ModelType.V_PREDICTION)
|
||||
|
||||
sd_model.is_sd3 = False
|
||||
sd_model.latent_channels = 4
|
||||
sd_model.is_sdxl = conditioner is not None
|
||||
@@ -234,14 +220,14 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
|
||||
|
||||
@torch.inference_mode()
|
||||
def patched_decode_first_stage(x):
|
||||
sample = sd_model.forge_objects.unet.model.model_config.latent_format.process_out(x)
|
||||
sample = sd_model.forge_objects.vae.first_stage_model.process_out(x)
|
||||
sample = sd_model.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||
return sample.to(x)
|
||||
|
||||
@torch.inference_mode()
|
||||
def patched_encode_first_stage(x):
|
||||
sample = sd_model.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
|
||||
sample = sd_model.forge_objects.unet.model.model_config.latent_format.process_in(sample)
|
||||
sample = sd_model.forge_objects.vae.first_stage_model.process_in(sample)
|
||||
return sample.to(x)
|
||||
|
||||
sd_model.ema_scope = lambda *args, **kwargs: contextlib.nullcontext()
|
||||
|
||||
Reference in New Issue
Block a user