mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-21 23:09:15 +00:00
Added a converter back to ldm from diffusers for sdxl. Can finally get to training it properly
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
import gc
|
||||
import typing
|
||||
from typing import Union, OrderedDict, List, Tuple
|
||||
from typing import Union, List, Tuple
|
||||
import sys
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
||||
from safetensors.torch import save_file
|
||||
@@ -10,11 +11,12 @@ from tqdm import tqdm
|
||||
from torchvision.transforms import Resize
|
||||
|
||||
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
|
||||
convert_vae_state_dict
|
||||
convert_vae_state_dict, load_vae
|
||||
from toolkit import train_tools
|
||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.saving import save_ldm_model_from_diffusers
|
||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
import torch
|
||||
from library import model_util
|
||||
@@ -27,6 +29,13 @@ import diffusers
|
||||
# tell it to shut up
|
||||
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
|
||||
|
||||
VAE_PREFIX_UNET = "vae"
|
||||
SD_PREFIX_UNET = "unet"
|
||||
SD_PREFIX_TEXT_ENCODER = "te"
|
||||
|
||||
SD_PREFIX_TEXT_ENCODER1 = "te1"
|
||||
SD_PREFIX_TEXT_ENCODER2 = "te2"
|
||||
|
||||
|
||||
class BlankNetwork:
|
||||
multiplier = 1.0
|
||||
@@ -218,6 +227,10 @@ class StableDiffusion:
|
||||
# scheduler doesn't get set sometimes, so we set it here
|
||||
pipe.scheduler = scheduler
|
||||
|
||||
if self.model_config.vae_path is not None:
|
||||
external_vae = load_vae(self.model_config.vae_path, dtype)
|
||||
pipe.vae = external_vae
|
||||
|
||||
self.unet = pipe.unet
|
||||
self.noise_scheduler = pipe.scheduler
|
||||
self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
|
||||
@@ -630,8 +643,33 @@ class StableDiffusion:
|
||||
|
||||
raise ValueError(f"Unknown weight name: {name}")
|
||||
|
||||
def state_dict(self, vae=True, text_encoder=True, unet=True):
|
||||
state_dict = OrderedDict()
|
||||
if vae:
|
||||
for k, v in self.vae.state_dict().items():
|
||||
new_key = k if k.startswith(f"{VAE_PREFIX_UNET}") else f"{VAE_PREFIX_UNET}_{k}"
|
||||
state_dict[new_key] = v
|
||||
if text_encoder:
|
||||
if isinstance(self.text_encoder, list):
|
||||
for i, encoder in enumerate(self.text_encoder):
|
||||
for k, v in encoder.state_dict().items():
|
||||
new_key = k if k.startswith(
|
||||
f"{SD_PREFIX_TEXT_ENCODER}{i}") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}"
|
||||
state_dict[new_key] = v
|
||||
else:
|
||||
for k, v in self.text_encoder.state_dict().items():
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER}") else f"{SD_PREFIX_TEXT_ENCODER}_{k}"
|
||||
state_dict[new_key] = v
|
||||
if unet:
|
||||
for k, v in self.unet.state_dict().items():
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_UNET}") else f"{SD_PREFIX_UNET}_{k}"
|
||||
state_dict[new_key] = v
|
||||
return state_dict
|
||||
|
||||
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
|
||||
state_dict = {}
|
||||
# prepare metadata
|
||||
meta = get_meta_for_safetensors(meta)
|
||||
|
||||
def update_sd(prefix, sd):
|
||||
for k, v in sd.items():
|
||||
@@ -644,14 +682,13 @@ class StableDiffusion:
|
||||
|
||||
# todo see what logit scale is
|
||||
if self.is_xl:
|
||||
# Convert the UNet model
|
||||
update_sd("model.diffusion_model.", self.unet.state_dict())
|
||||
|
||||
# Convert the text encoders
|
||||
update_sd("conditioner.embedders.0.transformer.", self.text_encoder[0].state_dict())
|
||||
|
||||
text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(self.text_encoder[1].state_dict(), logit_scale)
|
||||
update_sd("conditioner.embedders.1.model.", text_enc2_dict)
|
||||
save_ldm_model_from_diffusers(
|
||||
sd=self,
|
||||
output_file=output_file,
|
||||
meta=meta,
|
||||
save_dtype=save_dtype,
|
||||
sd_version='sdxl',
|
||||
)
|
||||
|
||||
else:
|
||||
# Convert the UNet model
|
||||
@@ -667,13 +704,11 @@ class StableDiffusion:
|
||||
text_enc_dict = self.text_encoder.state_dict()
|
||||
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
||||
|
||||
# Convert the VAE
|
||||
if self.vae is not None:
|
||||
vae_dict = model_util.convert_vae_state_dict(self.vae.state_dict())
|
||||
update_sd("first_stage_model.", vae_dict)
|
||||
# Convert the VAE
|
||||
if self.vae is not None:
|
||||
vae_dict = model_util.convert_vae_state_dict(self.vae.state_dict())
|
||||
update_sd("first_stage_model.", vae_dict)
|
||||
|
||||
# prepare metadata
|
||||
meta = get_meta_for_safetensors(meta)
|
||||
# make sure parent folder exists
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
save_file(state_dict, output_file, metadata=meta)
|
||||
# make sure parent folder exists
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
save_file(state_dict, output_file, metadata=meta)
|
||||
|
||||
Reference in New Issue
Block a user