Added a converter back to ldm from diffusers for sdxl. Can finally get to training it properly

This commit is contained in:
Jaret Burkett
2023-08-21 16:22:01 -06:00
parent e8667f856f
commit 36ba08d3fa
10 changed files with 4475 additions and 21 deletions
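The core artifact this commit produces is a keymap JSON in toolkit/keymaps that records, once, how ldm checkpoint keys line up with the flat-prefixed diffusers keys, plus the operators that cannot be expressed as a 1:1 rename. A sketch of its top-level shape, with one illustrative entry (the real file has thousands, generated by the matching script below):

keymap = {
    "ldm_diffusers_keymap": {
        # ldm key -> flat-prefixed diffusers key (example entry)
        "model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight",
    },
    "ldm_diffusers_shape_map": {},     # (ldm_shape, diffusers_shape) pairs for reshaped tensors
    "ldm_diffusers_operator_map": {},  # "cat" ops: join q/k/v back into merged in_proj tensors
    "diffusers_ldm_operator_map": {},  # "slice" ops: split in_proj back into q/k/v
}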

View File

@@ -95,7 +95,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
     if weight_jitter > 0.0:
         jitter_list = random.uniform(-weight_jitter, weight_jitter)
         network_pos_weight += jitter_list
-        network_neg_weight += jitter_list
+        network_neg_weight += (jitter_list * -1.0)
     # if items in network_weight list are tensors, convert them to floats

View File

@@ -248,7 +248,7 @@ class UltimateSliderTrainerProcess(BaseSDTrainProcess):
     if weight_jitter > 0.0:
         jitter_list = random.uniform(-weight_jitter, weight_jitter)
         network_pos_weight += jitter_list
-        network_neg_weight += jitter_list
+        network_neg_weight += (jitter_list * -1.0)
     # if items in network_weight list are tensors, convert them to floats
     imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)
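Both slider trainers get the same one-line fix: the jitter offset is now mirrored onto the negative network weight instead of added to it, so the two weights move in opposite directions and their sum stays put. A toy illustration (values made up):

import random

weight_jitter = 0.2
network_pos_weight = 1.0
network_neg_weight = 1.0

jitter_list = random.uniform(-weight_jitter, weight_jitter)  # say 0.05
network_pos_weight += jitter_list           # 1.05
network_neg_weight += (jitter_list * -1.0)  # 0.95: mirrored, instead of both drifting together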

View File

@@ -0,0 +1,332 @@
import argparse
import gc
import os
import re

import torch
from diffusers.loaders import LoraLoaderMixin
from safetensors.torch import load_file, save_file
from collections import OrderedDict
import json
from tqdm import tqdm

from toolkit.config_modules import ModelConfig
from toolkit.stable_diffusion_model import StableDiffusion

KEYMAPS_FOLDER = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'toolkit', 'keymaps')

device = torch.device('cpu')
dtype = torch.float32
def flush():
    torch.cuda.empty_cache()
    gc.collect()


def get_reduced_shape(shape_tuple):
    # iterate through the shape and drop any dims of size 1
    new_shape = []
    for dim in shape_tuple:
        if dim != 1:
            new_shape.append(dim)
    return tuple(new_shape)
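# e.g. (illustrative) get_reduced_shape((320, 1, 1)) -> (320,), so a weight an
# ldm checkpoint stores as [320, 1, 1] can still be compared element-wise below
# against the same weight a diffusers model stores as [320]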
parser = argparse.ArgumentParser()

# require at least one config file
parser.add_argument(
    'file_1',
    nargs='+',
    type=str,
    help='Path to first safe tensor file'
)

parser.add_argument('--name', type=str, default='stable_diffusion', help='name for mapping to make')
parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
parser.add_argument('--sd2', action='store_true', help='is sd 2 model')

args = parser.parse_args()
file_path = args.file_1[0]
find_matches = False

print('Loading diffusers model')
diffusers_model_config = ModelConfig(
    name_or_path=file_path,
    is_xl=args.sdxl,
    is_v2=args.sd2,
    dtype=dtype,
)
diffusers_sd = StableDiffusion(
    model_config=diffusers_model_config,
    device=device,
    dtype=dtype,
)
diffusers_sd.load_model()

# delete things we don't need
del diffusers_sd.tokenizer
flush()

print('Loading ldm model')
diffusers_state_dict = diffusers_sd.state_dict()
diffusers_dict_keys = list(diffusers_state_dict.keys())

ldm_state_dict = load_file(file_path)
ldm_dict_keys = list(ldm_state_dict.keys())

ldm_diffusers_keymap = OrderedDict()
ldm_diffusers_shape_map = OrderedDict()
ldm_operator_map = OrderedDict()
diffusers_operator_map = OrderedDict()

total_keys = len(ldm_dict_keys)

matched_ldm_keys = []
matched_diffusers_keys = []

error_margin = 1e-4
if args.sdxl:
    # handle known merged tensors up front
    for ldm_key in ldm_dict_keys:
        pattern = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
        match = re.match(pattern, ldm_key)
        if match:
            number = int(match.group(1))
            new_val = torch.cat([
                diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.weight"],
                diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.weight"],
                diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.weight"],
            ], dim=0)
            # add to matched so we don't check them again below
            matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.weight")
            matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.weight")
            matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.weight")
            # register the merged tensor in the diffusers dict so the matcher can find it
            diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.MERGED.weight"] = new_val
            # add operator
            ldm_operator_map[ldm_key] = {
                "cat": [
                    f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.weight",
                    f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.weight",
                    f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.weight",
                ],
                "target": f"te1_text_model.encoder.layers.{number}.self_attn.MERGED.weight"
            }

            if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
                # d_model = int(checkpoint[prefix + "text_projection"].shape[0])
                d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
            else:
                d_model = 1024

            # text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
            # text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model: d_model * 2, :]
            # text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2:, :]

            # add diffusers operators (the slice strings are parsed back at load time)
            diffusers_operator_map[f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.weight"] = {
                "slice": [
                    f"conditioner.embedders.1.model.transformer.resblocks.{number}.attn.in_proj_weight",
                    f"0:{d_model}, :"
                ]
            }
            diffusers_operator_map[f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.weight"] = {
                "slice": [
                    f"conditioner.embedders.1.model.transformer.resblocks.{number}.attn.in_proj_weight",
                    f"{d_model}:{d_model * 2}, :"
                ]
            }
            diffusers_operator_map[f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.weight"] = {
                "slice": [
                    f"conditioner.embedders.1.model.transformer.resblocks.{number}.attn.in_proj_weight",
                    f"{d_model * 2}:, :"
                ]
            }

        pattern = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
        match = re.match(pattern, ldm_key)
        if match:
            number = int(match.group(1))
            new_val = torch.cat([
                diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.bias"],
                diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.bias"],
                diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.bias"],
            ], dim=0)
            # add to matched so we don't check them again below
            matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.bias")
            matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.bias")
            matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.bias")
            # register the merged tensor in the diffusers dict so the matcher can find it
            diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.MERGED.bias"] = new_val
            # add operator
            ldm_operator_map[ldm_key] = {
                "cat": [
                    f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.bias",
                    f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.bias",
                    f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.bias",
                ],
                "target": f"te1_text_model.encoder.layers.{number}.self_attn.MERGED.bias"
            }

    # update keys
    diffusers_dict_keys = list(diffusers_state_dict.keys())
pbar = tqdm(ldm_dict_keys, desc='Matching ldm-diffusers keys', total=total_keys)

# run through all weights and check mse between them to find matches
for ldm_key in ldm_dict_keys:
    ldm_shape_tuple = ldm_state_dict[ldm_key].shape
    ldm_reduced_shape_tuple = get_reduced_shape(ldm_shape_tuple)
    for diffusers_key in diffusers_dict_keys:
        diffusers_shape_tuple = diffusers_state_dict[diffusers_key].shape
        diffusers_reduced_shape_tuple = get_reduced_shape(diffusers_shape_tuple)

        # That was easy. Same key
        if ldm_key == diffusers_key:
            ldm_diffusers_keymap[ldm_key] = diffusers_key
            matched_ldm_keys.append(ldm_key)
            matched_diffusers_keys.append(diffusers_key)
            break

        # if we already have this key mapped, skip it
        if diffusers_key in matched_diffusers_keys:
            continue

        # if reduced shapes do not match, skip it
        if ldm_reduced_shape_tuple != diffusers_reduced_shape_tuple:
            continue

        ldm_weight = ldm_state_dict[ldm_key]
        did_reduce_ldm = False
        diffusers_weight = diffusers_state_dict[diffusers_key]
        did_reduce_diffusers = False

        # reduce the shapes to match if they are not the same
        if ldm_shape_tuple != ldm_reduced_shape_tuple:
            ldm_weight = ldm_weight.view(ldm_reduced_shape_tuple)
            did_reduce_ldm = True

        if diffusers_shape_tuple != diffusers_reduced_shape_tuple:
            diffusers_weight = diffusers_weight.view(diffusers_reduced_shape_tuple)
            did_reduce_diffusers = True

        # check to see if they match within a margin of error
        mse = torch.nn.functional.mse_loss(ldm_weight, diffusers_weight)
        if mse < error_margin:
            ldm_diffusers_keymap[ldm_key] = diffusers_key
            matched_ldm_keys.append(ldm_key)
            matched_diffusers_keys.append(diffusers_key)
            if did_reduce_ldm or did_reduce_diffusers:
                ldm_diffusers_shape_map[ldm_key] = (ldm_shape_tuple, diffusers_shape_tuple)
            if did_reduce_ldm:
                del ldm_weight
            if did_reduce_diffusers:
                del diffusers_weight
            flush()
            break
    pbar.update(1)
pbar.close()
name = args.name
if args.sdxl:
    name += '_sdxl'
elif args.sd2:
    name += '_sd2'
else:
    name += '_sd1'

# if len(matched_ldm_keys) != len(matched_diffusers_keys):
unmatched_ldm_keys = [x for x in ldm_dict_keys if x not in matched_ldm_keys]
unmatched_diffusers_keys = [x for x in diffusers_dict_keys if x not in matched_diffusers_keys]

# has unmatched keys
has_unmatched_keys = len(unmatched_ldm_keys) > 0 or len(unmatched_diffusers_keys) > 0


def get_slices_from_string(s: str) -> tuple:
    # parse a spec like "0:1280, :" into (slice(0, 1280), slice(None));
    # parsed directly rather than eval'd, since "slice(0:1280)" is not valid Python
    slices = []
    for component in s.split(','):
        parts = [int(p) if p.strip() else None for p in component.strip().split(':')]
        slices.append(slice(*parts))
    return tuple(slices)
if has_unmatched_keys:
    print(
        f"Found {len(unmatched_ldm_keys)} unmatched ldm keys and {len(unmatched_diffusers_keys)} unmatched diffusers keys")

    unmatched_obj = OrderedDict()
    unmatched_obj['ldm'] = OrderedDict()
    unmatched_obj['diffusers'] = OrderedDict()

    print("Gathering info on unmatched keys")
    for key in tqdm(unmatched_ldm_keys, desc='Unmatched LDM keys'):
        # get min, max, mean, std
        weight = ldm_state_dict[key]
        weight_min = weight.min().item()
        weight_max = weight.max().item()
        weight_mean = weight.mean().item()
        weight_std = weight.std().item()
        unmatched_obj['ldm'][key] = {
            "shape": weight.shape,
            "min": weight_min,
            "max": weight_max,
            "mean": weight_mean,
            "std": weight_std,
        }
        del weight
        flush()

    for key in tqdm(unmatched_diffusers_keys, desc='Unmatched Diffusers keys'):
        # get min, max, mean, std
        weight = diffusers_state_dict[key]
        weight_min = weight.min().item()
        weight_max = weight.max().item()
        weight_mean = weight.mean().item()
        weight_std = weight.std().item()
        unmatched_obj['diffusers'][key] = {
            "shape": weight.shape,
            "min": weight_min,
            "max": weight_max,
            "mean": weight_mean,
            "std": weight_std,
        }
        del weight
        flush()

    unmatched_path = os.path.join(KEYMAPS_FOLDER, f'{name}_unmatched.json')
    with open(unmatched_path, 'w') as f:
        f.write(json.dumps(unmatched_obj, indent=4))
    print(f'Saved unmatched keys to {unmatched_path}')

    # save the ldm remainders so conversion can restore keys diffusers doesn't carry
    remaining_ldm_values = OrderedDict()
    for key in unmatched_ldm_keys:
        remaining_ldm_values[key] = ldm_state_dict[key].detach().to('cpu', torch.float16)
    save_file(remaining_ldm_values, os.path.join(KEYMAPS_FOLDER, f'{name}_ldm_base.safetensors'))
    print(f'Saved remaining ldm values to {os.path.join(KEYMAPS_FOLDER, f"{name}_ldm_base.safetensors")}')

dest_path = os.path.join(KEYMAPS_FOLDER, f'{name}.json')

save_obj = OrderedDict()
save_obj["ldm_diffusers_keymap"] = ldm_diffusers_keymap
save_obj["ldm_diffusers_shape_map"] = ldm_diffusers_shape_map
save_obj["ldm_diffusers_operator_map"] = ldm_operator_map
save_obj["diffusers_ldm_operator_map"] = diffusers_operator_map

with open(dest_path, 'w') as f:
    f.write(json.dumps(save_obj, indent=4))

print(f'Saved keymap to {dest_path}')
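A typical invocation might look like this (the script name and checkpoint path are illustrative, not from the commit). The script takes a single ldm .safetensors checkpoint, loads it once through diffusers and once raw, then writes the keymap JSON and the *_ldm_base.safetensors remainder into toolkit/keymaps:

python tools/generate_keymap.py /path/to/sd_xl_base_1.0.safetensors --name stable_diffusion --sdxl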

View File

@@ -77,6 +77,7 @@ class ModelConfig:
         self.is_xl: bool = kwargs.get('is_xl', False)
         self.is_v_pred: bool = kwargs.get('is_v_pred', False)
         self.dtype: str = kwargs.get('dtype', 'float16')
+        self.vae_path = kwargs.get('vae_path', None)
         # only for SDXL models for now
         self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True)
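The new vae_path option lets a run swap in an external VAE at load time (wired up in stable_diffusion_model.py below). A hypothetical usage, with illustrative paths:

from toolkit.config_modules import ModelConfig

model_config = ModelConfig(
    name_or_path='/path/to/sdxl_model.safetensors',  # illustrative
    is_xl=True,
    vae_path='/path/to/external_vae.safetensors',    # new in this commit; defaults to None
)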

File diff suppressed because it is too large

View File

@@ -0,0 +1,43 @@
{
    "ldm": {
        "conditioner.embedders.0.transformer.text_model.embeddings.position_ids": {
            "shape": [
                1,
                77
            ],
            "min": 0.0,
            "max": 76.0,
            "mean": 38.0,
            "std": 22.375
        },
        "conditioner.embedders.1.model.logit_scale": {
            "shape": [],
            "min": 4.60546875,
            "max": 4.60546875,
            "mean": 4.60546875,
            "std": NaN
        },
        "conditioner.embedders.1.model.text_projection": {
            "shape": [
                1280,
                1280
            ],
            "min": -0.15966796875,
            "max": 0.230712890625,
            "mean": 0.0,
            "std": 0.0181732177734375
        }
    },
    "diffusers": {
        "te1_text_projection.weight": {
            "shape": [
                1280,
                1280
            ],
            "min": -0.15966796875,
            "max": 0.230712890625,
            "mean": 2.128152846125886e-05,
            "std": 0.018169498071074486
        }
    }
}

View File

@@ -4,6 +4,7 @@ TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
 SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
 REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories")
+KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps")
 # check if ENV variable is set
 if 'MODELS_PATH' in os.environ:

toolkit/saving.py (new file, 98 lines)
View File

@@ -0,0 +1,98 @@
import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Literal, Optional, Union

import torch
from safetensors.torch import load_file, save_file

from toolkit.train_tools import get_torch_dtype
from toolkit.paths import KEYMAPS_ROOT

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion


def get_slices_from_string(s: str) -> tuple:
    # parse a spec like "0:1280, :" into (slice(0, 1280), slice(None));
    # parsed directly rather than eval'd, since "slice(0:1280)" is not valid Python
    slices = []
    for component in s.split(','):
        parts = [int(p) if p.strip() else None for p in component.strip().split(':')]
        slices.append(slice(*parts))
    return tuple(slices)
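# For example (illustrative): get_slices_from_string("0:1280, :") returns
# (slice(0, 1280, None), slice(None, None, None)), so tensor[get_slices_from_string("0:1280, :")]
# is the same as tensor[0:1280, :] -- exactly the q/k/v splits the keymap
# generator writes into its "slice" operators.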
def convert_state_dict_to_ldm_with_mapping(
        diffusers_state_dict: 'OrderedDict',
        mapping_path: str,
        base_path: Union[str, None] = None,
        device: str = 'cpu',
        dtype: torch.dtype = torch.float32
) -> 'OrderedDict':
    converted_state_dict = OrderedDict()

    # load mapping
    with open(mapping_path, 'r') as f:
        mapping = json.load(f, object_pairs_hook=OrderedDict)

    ldm_diffusers_keymap = mapping['ldm_diffusers_keymap']
    ldm_diffusers_shape_map = mapping['ldm_diffusers_shape_map']
    ldm_diffusers_operator_map = mapping['ldm_diffusers_operator_map']

    # load the base if it exists
    # the base just has some keys, like timing ids, that diffusers doesn't have or that don't match
    if base_path is not None:
        converted_state_dict = load_file(base_path, device)
        # convert to the right dtype
        for key in converted_state_dict:
            converted_state_dict[key] = converted_state_dict[key].to(device, dtype=dtype)

    # process operators first
    for ldm_key, operator in ldm_diffusers_operator_map.items():
        # a "cat" operator concatenates several diffusers tensors into one ldm tensor
        if 'cat' in operator:
            cat_list = []
            for diffusers_key in operator['cat']:
                cat_list.append(diffusers_state_dict[diffusers_key].detach())
            converted_state_dict[ldm_key] = torch.cat(cat_list, dim=0).to(device, dtype=dtype)
        # a "slice" operator takes a slice of a single diffusers tensor
        if 'slice' in operator:
            tensor_to_slice = diffusers_state_dict[operator['slice'][0]]
            slice_text = operator['slice'][1]
            converted_state_dict[ldm_key] = tensor_to_slice[get_slices_from_string(slice_text)] \
                .detach().to(device, dtype=dtype)

    # process the rest of the keys
    for ldm_key in ldm_diffusers_keymap:
        if ldm_diffusers_keymap[ldm_key] in diffusers_state_dict:
            tensor = diffusers_state_dict[ldm_diffusers_keymap[ldm_key]].detach().to(device, dtype=dtype)
            # see if we need to reshape back to the ldm shape
            if ldm_key in ldm_diffusers_shape_map:
                tensor = tensor.view(ldm_diffusers_shape_map[ldm_key][0])
            converted_state_dict[ldm_key] = tensor

    return converted_state_dict
def save_ldm_model_from_diffusers(
        sd: 'StableDiffusion',
        output_file: str,
        meta: 'OrderedDict',
        save_dtype=get_torch_dtype('fp16'),
        sd_version: Literal['1', '2', 'sdxl'] = '2'
):
    if sd_version != 'sdxl':
        # not supported yet
        raise NotImplementedError("Only SDXL is supported at this time with this method")

    # load our base
    base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
    mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')

    # convert the state dict
    converted_state_dict = convert_state_dict_to_ldm_with_mapping(
        sd.state_dict(),
        mapping_path,
        base_path,
        device='cpu',
        dtype=save_dtype
    )

    # make sure the parent folder exists
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    save_file(converted_state_dict, output_file, metadata=meta)
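Saving back to ldm is then a single call. A minimal sketch, assuming `sd` is an already-loaded SDXL `StableDiffusion` instance and the paths are illustrative:

from collections import OrderedDict

from toolkit.saving import save_ldm_model_from_diffusers
from toolkit.train_tools import get_torch_dtype

save_ldm_model_from_diffusers(
    sd=sd,  # assumed: a loaded StableDiffusion with is_xl=True
    output_file='/path/to/out/sdxl_finetune.safetensors',  # illustrative
    meta=OrderedDict(),  # safetensors metadata (str -> str); normally built via get_meta_for_safetensors
    save_dtype=get_torch_dtype('fp16'),
    sd_version='sdxl',
)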

View File

@@ -1,8 +1,9 @@
 import gc
 import typing
-from typing import Union, OrderedDict, List, Tuple
+from typing import Union, List, Tuple
 import sys
 import os
+from collections import OrderedDict
 from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
 from safetensors.torch import save_file
@@ -10,11 +11,12 @@ from tqdm import tqdm
 from torchvision.transforms import Resize
 from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
-    convert_vae_state_dict
+    convert_vae_state_dict, load_vae
 from toolkit import train_tools
 from toolkit.config_modules import ModelConfig, GenerateImageConfig
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.paths import REPOS_ROOT
+from toolkit.saving import save_ldm_model_from_diffusers
 from toolkit.train_tools import get_torch_dtype, apply_noise_offset
 import torch
 from library import model_util
@@ -27,6 +29,13 @@ import diffusers
 # tell it to shut up
 diffusers.logging.set_verbosity(diffusers.logging.ERROR)
+
+VAE_PREFIX_UNET = "vae"
+SD_PREFIX_UNET = "unet"
+SD_PREFIX_TEXT_ENCODER = "te"
+SD_PREFIX_TEXT_ENCODER1 = "te1"
+SD_PREFIX_TEXT_ENCODER2 = "te2"
+
 class BlankNetwork:
     multiplier = 1.0
@@ -218,6 +227,10 @@ class StableDiffusion:
         # scheduler doesn't get set sometimes, so we set it here
         pipe.scheduler = scheduler

+        if self.model_config.vae_path is not None:
+            external_vae = load_vae(self.model_config.vae_path, dtype)
+            pipe.vae = external_vae
+
         self.unet = pipe.unet
         self.noise_scheduler = pipe.scheduler
         self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
@@ -630,8 +643,33 @@ class StableDiffusion:
             raise ValueError(f"Unknown weight name: {name}")

+    def state_dict(self, vae=True, text_encoder=True, unet=True):
+        state_dict = OrderedDict()
+        if vae:
+            for k, v in self.vae.state_dict().items():
+                new_key = k if k.startswith(f"{VAE_PREFIX_UNET}") else f"{VAE_PREFIX_UNET}_{k}"
+                state_dict[new_key] = v
+        if text_encoder:
+            if isinstance(self.text_encoder, list):
+                for i, encoder in enumerate(self.text_encoder):
+                    for k, v in encoder.state_dict().items():
+                        new_key = k if k.startswith(
+                            f"{SD_PREFIX_TEXT_ENCODER}{i}") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}"
+                        state_dict[new_key] = v
+            else:
+                for k, v in self.text_encoder.state_dict().items():
+                    new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER}") else f"{SD_PREFIX_TEXT_ENCODER}_{k}"
+                    state_dict[new_key] = v
+        if unet:
+            for k, v in self.unet.state_dict().items():
+                new_key = k if k.startswith(f"{SD_PREFIX_UNET}") else f"{SD_PREFIX_UNET}_{k}"
+                state_dict[new_key] = v
+        return state_dict
+
     def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
         state_dict = {}
+        # prepare metadata
+        meta = get_meta_for_safetensors(meta)

         def update_sd(prefix, sd):
             for k, v in sd.items():
@@ -644,14 +682,13 @@ class StableDiffusion:
         # todo see what logit scale is
         if self.is_xl:
-            # Convert the UNet model
-            update_sd("model.diffusion_model.", self.unet.state_dict())
-
-            # Convert the text encoders
-            update_sd("conditioner.embedders.0.transformer.", self.text_encoder[0].state_dict())
-
-            text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(self.text_encoder[1].state_dict(), logit_scale)
-            update_sd("conditioner.embedders.1.model.", text_enc2_dict)
+            save_ldm_model_from_diffusers(
+                sd=self,
+                output_file=output_file,
+                meta=meta,
+                save_dtype=save_dtype,
+                sd_version='sdxl',
+            )
         else:
             # Convert the UNet model
@@ -667,13 +704,11 @@ class StableDiffusion:
             text_enc_dict = self.text_encoder.state_dict()
             update_sd("cond_stage_model.transformer.", text_enc_dict)

             # Convert the VAE
             if self.vae is not None:
                 vae_dict = model_util.convert_vae_state_dict(self.vae.state_dict())
                 update_sd("first_stage_model.", vae_dict)

-        # prepare metadata
-        meta = get_meta_for_safetensors(meta)
-        # make sure parent folder exists
-        os.makedirs(os.path.dirname(output_file), exist_ok=True)
-        save_file(state_dict, output_file, metadata=meta)
+            # make sure parent folder exists
+            os.makedirs(os.path.dirname(output_file), exist_ok=True)
+            save_file(state_dict, output_file, metadata=meta)
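The new `state_dict()` override is what the keymap generator matches against: instead of nested module dicts, every tensor lands in one flat dict under a short prefix. A sketch of the resulting key names for an SDXL model (keys illustrative, prefixes from the constants added above):

sd_dict = sd.state_dict()  # sd is a loaded StableDiffusion wrapper
# "vae_encoder.conv_in.weight"                              <- vae, prefix "vae"
# "te0_text_model.embeddings.token_embedding.weight"        <- first text encoder
# "te1_text_model.encoder.layers.0.self_attn.q_proj.weight" <- second text encoder,
#     matching the te1_... keys the keymap generator merges and slices
# "unet_conv_in.weight"                                     <- unet, prefix "unet"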