UNet from Scratch

The backend rewrite is now about 50% finished.
Estimated completion is in about 72 hours.
After that, many new features will land.
This commit is contained in:
layerdiffusion
2024-08-01 21:19:41 -07:00
parent e3522c8919
commit bc9977a305
20 changed files with 1393 additions and 56 deletions

View File

@@ -8,6 +8,7 @@ from ldm_patched.modules.sd import VAE, CLIP, load_model_weights
import ldm_patched.modules.model_patcher
import ldm_patched.modules.utils
import ldm_patched.modules.clip_vision
import backend.nn.unet
from omegaconf import OmegaConf
from modules.sd_models_config import find_checkpoint_config
@@ -19,6 +20,7 @@ from modules_forge import forge_clip
from modules_forge.unet_patcher import UnetPatcher
from ldm_patched.modules.model_base import model_sampling, ModelType
from backend.loader import load_huggingface_components
from backend.modules.k_model import KModel
import open_clip
from transformers import CLIPTextModel, CLIPTokenizer
@@ -85,27 +87,20 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype
class WeightsLoader(torch.nn.Module):
pass
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
model_config.set_manual_cast(manual_cast_dtype)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type")
if model_config.clip_vision_prefix is not None:
if output_clipvision:
clipvision = ldm_patched.modules.clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
initial_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
backend.nn.unet.unet_initial_device = initial_load_device
backend.nn.unet.unet_initial_dtype = unet_dtype
huggingface_components = load_huggingface_components(sd)
if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
offload_device = model_management.unet_offload_device()
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
model.load_model_weights(sd, "model.diffusion_model.")
k_model = KModel(huggingface_components, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype)
k_model.to(device=initial_load_device, dtype=unet_dtype)
model_patcher = UnetPatcher(k_model, load_device=load_device,
offload_device=model_management.unet_offload_device(),
current_device=initial_load_device)
if output_vae:
vae = huggingface_components['vae']
@@ -118,12 +113,6 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
if len(left_over) > 0:
print("left over keys:", left_over)
if output_model:
model_patcher = UnetPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
if inital_load_device != torch.device("cpu"):
print("loaded straight to GPU")
model_management.load_model_gpu(model_patcher)
return ForgeSD(model_patcher, clip, vae, clipvision)
@@ -161,7 +150,7 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
timer.record("forge load real models")
sd_model.first_stage_model = forge_objects.vae.first_stage_model
sd_model.model.diffusion_model = forge_objects.unet.model.diffusion_model
sd_model.model.diffusion_model = forge_objects.unet.model
conditioner = getattr(sd_model, 'conditioner', None)
if conditioner:
@@ -202,8 +191,8 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
model_embeddings.token_embedding, sd_hijack.model_hijack)
sd_model.cond_stage_model = forge_clip.CLIP_SD_15_L(sd_model.cond_stage_model, sd_hijack.model_hijack)
elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder': # SD21 Clip
sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_h
sd_model.cond_stage_model.transformer = forge_objects.clip.cond_stage_model.clip_h.transformer
sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_l
sd_model.cond_stage_model.transformer = forge_objects.clip.cond_stage_model.clip_l.transformer
model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
model_embeddings.token_embedding, sd_hijack.model_hijack)
@@ -216,9 +205,6 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
if getattr(sd_model, 'parameterization', None) == 'v':
sd_model.forge_objects.unet.model.model_sampling = model_sampling(sd_model.forge_objects.unet.model.model_config, ModelType.V_PREDICTION)
sd_model.is_sd3 = False
sd_model.latent_channels = 4
sd_model.is_sdxl = conditioner is not None
@@ -234,14 +220,14 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
@torch.inference_mode()
def patched_decode_first_stage(x):
sample = sd_model.forge_objects.unet.model.model_config.latent_format.process_out(x)
sample = sd_model.forge_objects.vae.first_stage_model.process_out(x)
sample = sd_model.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
return sample.to(x)
@torch.inference_mode()
def patched_encode_first_stage(x):
sample = sd_model.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
sample = sd_model.forge_objects.unet.model.model_config.latent_format.process_in(sample)
sample = sd_model.forge_objects.vae.first_stage_model.process_in(sample)
return sample.to(x)
sd_model.ema_scope = lambda *args, **kwargs: contextlib.nullcontext()