Added a converter back to ldm from diffusers for sdxl. Can finally get to training it properly

This commit is contained in:
Jaret Burkett
2023-08-21 16:22:01 -06:00
parent e8667f856f
commit 36ba08d3fa
10 changed files with 4475 additions and 21 deletions

View File

@@ -77,6 +77,7 @@ class ModelConfig:
self.is_xl: bool = kwargs.get('is_xl', False)
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
self.dtype: str = kwargs.get('dtype', 'float16')
self.vae_path = kwargs.get('vae_path', None)
# only for SDXL models for now
self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,43 @@
{
"ldm": {
"conditioner.embedders.0.transformer.text_model.embeddings.position_ids": {
"shape": [
1,
77
],
"min": 0.0,
"max": 76.0,
"mean": 38.0,
"std": 22.375
},
"conditioner.embedders.1.model.logit_scale": {
"shape": [],
"min": 4.60546875,
"max": 4.60546875,
"mean": 4.60546875,
"std": NaN
},
"conditioner.embedders.1.model.text_projection": {
"shape": [
1280,
1280
],
"min": -0.15966796875,
"max": 0.230712890625,
"mean": 0.0,
"std": 0.0181732177734375
}
},
"diffusers": {
"te1_text_projection.weight": {
"shape": [
1280,
1280
],
"min": -0.15966796875,
"max": 0.230712890625,
"mean": 2.128152846125886e-05,
"std": 0.018169498071074486
}
}
}

View File

@@ -4,6 +4,7 @@ TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories")
KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps")
# check if ENV variable is set
if 'MODELS_PATH' in os.environ:

98
toolkit/saving.py Normal file
View File

@@ -0,0 +1,98 @@
import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Literal, Optional, Union
import torch
from safetensors.torch import load_file, save_file
from toolkit.train_tools import get_torch_dtype
from toolkit.paths import KEYMAPS_ROOT
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
def get_slices_from_string(s: str) -> tuple:
slice_strings = s.split(',')
slices = [eval(f"slice({component.strip()})") for component in slice_strings]
return tuple(slices)
def convert_state_dict_to_ldm_with_mapping(
diffusers_state_dict: 'OrderedDict',
mapping_path: str,
base_path: Union[str, None] = None,
device: str = 'cpu',
dtype: torch.dtype = torch.float32
) -> 'OrderedDict':
converted_state_dict = OrderedDict()
# load mapping
with open(mapping_path, 'r') as f:
mapping = json.load(f, object_pairs_hook=OrderedDict)
ldm_diffusers_keymap = mapping['ldm_diffusers_keymap']
ldm_diffusers_shape_map = mapping['ldm_diffusers_shape_map']
ldm_diffusers_operator_map = mapping['ldm_diffusers_operator_map']
# load base if it exists
# the base just has come keys like timing ids and stuff diffusers doesn't have or they don't match
if base_path is not None:
converted_state_dict = load_file(base_path, device)
# convert to the right dtype
for key in converted_state_dict:
converted_state_dict[key] = converted_state_dict[key].to(device, dtype=dtype)
# process operators first
for ldm_key in ldm_diffusers_operator_map:
# if the key cat is in the ldm key, we need to process it
if 'cat' in ldm_key:
cat_list = []
for diffusers_key in ldm_diffusers_operator_map[ldm_key]['cat']:
cat_list.append(diffusers_state_dict[diffusers_key].detatch())
converted_state_dict[ldm_key] = torch.cat(cat_list, dim=0).to(device, dtype=dtype)
if 'slice' in ldm_key:
tensor_to_slice = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][0]]
slice_text = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][1]]
converted_state_dict[ldm_key] = tensor_to_slice[get_slices_from_string(slice_text)].detatch().to(device,
dtype=dtype)
# process the rest of the keys
for ldm_key in ldm_diffusers_keymap:
# if the key is in the ldm key, we need to process it
if ldm_diffusers_keymap[ldm_key] in diffusers_state_dict:
tensor = diffusers_state_dict[ldm_diffusers_keymap[ldm_key]].detach().to(device, dtype=dtype)
# see if we need to reshape
if ldm_key in ldm_diffusers_shape_map:
tensor = tensor.view(ldm_diffusers_shape_map[ldm_key][0])
converted_state_dict[ldm_key] = tensor
return converted_state_dict
def save_ldm_model_from_diffusers(
sd: 'StableDiffusion',
output_file: str,
meta: 'OrderedDict',
save_dtype=get_torch_dtype('fp16'),
sd_version: Literal['1', '2', 'sdxl'] = '2'
):
if sd_version != 'sdxl':
# not supported yet
raise NotImplementedError("Only SDXL is supported at this time with this method")
# load our base
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')
# convert the state dict
converted_state_dict = convert_state_dict_to_ldm_with_mapping(
sd.state_dict(),
mapping_path,
base_path,
device='cpu',
dtype=save_dtype
)
# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
save_file(converted_state_dict, output_file, metadata=meta)

View File

@@ -1,8 +1,9 @@
import gc
import typing
from typing import Union, OrderedDict, List, Tuple
from typing import Union, List, Tuple
import sys
import os
from collections import OrderedDict
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from safetensors.torch import save_file
@@ -10,11 +11,12 @@ from tqdm import tqdm
from torchvision.transforms import Resize
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
convert_vae_state_dict
convert_vae_state_dict, load_vae
from toolkit import train_tools
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.paths import REPOS_ROOT
from toolkit.saving import save_ldm_model_from_diffusers
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import torch
from library import model_util
@@ -27,6 +29,13 @@ import diffusers
# tell it to shut up
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
VAE_PREFIX_UNET = "vae"
SD_PREFIX_UNET = "unet"
SD_PREFIX_TEXT_ENCODER = "te"
SD_PREFIX_TEXT_ENCODER1 = "te1"
SD_PREFIX_TEXT_ENCODER2 = "te2"
class BlankNetwork:
multiplier = 1.0
@@ -218,6 +227,10 @@ class StableDiffusion:
# scheduler doesn't get set sometimes, so we set it here
pipe.scheduler = scheduler
if self.model_config.vae_path is not None:
external_vae = load_vae(self.model_config.vae_path, dtype)
pipe.vae = external_vae
self.unet = pipe.unet
self.noise_scheduler = pipe.scheduler
self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
@@ -630,8 +643,33 @@ class StableDiffusion:
raise ValueError(f"Unknown weight name: {name}")
def state_dict(self, vae=True, text_encoder=True, unet=True):
state_dict = OrderedDict()
if vae:
for k, v in self.vae.state_dict().items():
new_key = k if k.startswith(f"{VAE_PREFIX_UNET}") else f"{VAE_PREFIX_UNET}_{k}"
state_dict[new_key] = v
if text_encoder:
if isinstance(self.text_encoder, list):
for i, encoder in enumerate(self.text_encoder):
for k, v in encoder.state_dict().items():
new_key = k if k.startswith(
f"{SD_PREFIX_TEXT_ENCODER}{i}") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}"
state_dict[new_key] = v
else:
for k, v in self.text_encoder.state_dict().items():
new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER}") else f"{SD_PREFIX_TEXT_ENCODER}_{k}"
state_dict[new_key] = v
if unet:
for k, v in self.unet.state_dict().items():
new_key = k if k.startswith(f"{SD_PREFIX_UNET}") else f"{SD_PREFIX_UNET}_{k}"
state_dict[new_key] = v
return state_dict
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
state_dict = {}
# prepare metadata
meta = get_meta_for_safetensors(meta)
def update_sd(prefix, sd):
for k, v in sd.items():
@@ -644,14 +682,13 @@ class StableDiffusion:
# todo see what logit scale is
if self.is_xl:
# Convert the UNet model
update_sd("model.diffusion_model.", self.unet.state_dict())
# Convert the text encoders
update_sd("conditioner.embedders.0.transformer.", self.text_encoder[0].state_dict())
text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(self.text_encoder[1].state_dict(), logit_scale)
update_sd("conditioner.embedders.1.model.", text_enc2_dict)
save_ldm_model_from_diffusers(
sd=self,
output_file=output_file,
meta=meta,
save_dtype=save_dtype,
sd_version='sdxl',
)
else:
# Convert the UNet model
@@ -667,13 +704,11 @@ class StableDiffusion:
text_enc_dict = self.text_encoder.state_dict()
update_sd("cond_stage_model.transformer.", text_enc_dict)
# Convert the VAE
if self.vae is not None:
vae_dict = model_util.convert_vae_state_dict(self.vae.state_dict())
update_sd("first_stage_model.", vae_dict)
# Convert the VAE
if self.vae is not None:
vae_dict = model_util.convert_vae_state_dict(self.vae.state_dict())
update_sd("first_stage_model.", vae_dict)
# prepare metadata
meta = get_meta_for_safetensors(meta)
# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
save_file(state_dict, output_file, metadata=meta)
# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
save_file(state_dict, output_file, metadata=meta)