mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added converters for all Stable Diffusion models to convert from the diffusers format back to the LDM format.
This commit is contained in:
@@ -2,6 +2,11 @@ import argparse
|
||||
import gc
|
||||
import os
|
||||
import re
|
||||
import os
|
||||
# add project root to sys path
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import torch
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
@@ -90,90 +95,134 @@ matched_diffusers_keys = []
|
||||
|
||||
error_margin = 1e-4
|
||||
|
||||
te_suffix = ''
|
||||
proj_pattern_weight = None
|
||||
proj_pattern_bias = None
|
||||
text_proj_layer = None
|
||||
if args.sdxl:
|
||||
te_suffix = '1'
|
||||
ldm_res_block_prefix = "conditioner.embedders.1.model.transformer.resblocks"
|
||||
proj_pattern_weight = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
|
||||
proj_pattern_bias = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
|
||||
text_proj_layer = "conditioner.embedders.1.model.text_projection"
|
||||
if args.sd2:
|
||||
te_suffix = ''
|
||||
ldm_res_block_prefix = "cond_stage_model.model.transformer.resblocks"
|
||||
proj_pattern_weight = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
|
||||
proj_pattern_bias = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
|
||||
text_proj_layer = "cond_stage_model.model.text_projection"
|
||||
|
||||
if args.sdxl or args.sd2:
|
||||
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
|
||||
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
|
||||
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
|
||||
else:
|
||||
d_model = 1024
|
||||
|
||||
# perform the merges that are already known (q/k/v projections -> in_proj)
|
||||
for ldm_key in ldm_dict_keys:
|
||||
pattern = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
|
||||
match = re.match(pattern, ldm_key)
|
||||
if match:
|
||||
number = int(match.group(1))
|
||||
new_val = torch.cat([
|
||||
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.weight"],
|
||||
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.weight"],
|
||||
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.weight"],
|
||||
], dim=0)
|
||||
# add to matched so we don't check them
|
||||
matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.weight")
|
||||
matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.weight")
|
||||
matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.weight")
|
||||
# make diffusers convertable_dict
|
||||
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.MERGED.weight"] = new_val
|
||||
try:
|
||||
match = re.match(proj_pattern_weight, ldm_key)
|
||||
if match:
|
||||
number = int(match.group(1))
|
||||
new_val = torch.cat([
|
||||
diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight"],
|
||||
diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight"],
|
||||
diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight"],
|
||||
], dim=0)
|
||||
# add to matched so we don't check them
|
||||
matched_diffusers_keys.append(
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight")
|
||||
matched_diffusers_keys.append(
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight")
|
||||
matched_diffusers_keys.append(
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight")
|
||||
# make diffusers convertable_dict
|
||||
diffusers_state_dict[
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.weight"] = new_val
|
||||
|
||||
# add operator
|
||||
ldm_operator_map[ldm_key] = {
|
||||
"cat": [
|
||||
f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.weight",
|
||||
f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.weight",
|
||||
f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.weight",
|
||||
],
|
||||
"target": f"te1_text_model.encoder.layers.{number}.self_attn.MERGED.weight"
|
||||
}
|
||||
# add operator
|
||||
ldm_operator_map[ldm_key] = {
|
||||
"cat": [
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight",
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight",
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight",
|
||||
],
|
||||
"target": f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.weight"
|
||||
}
|
||||
|
||||
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
|
||||
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
|
||||
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
|
||||
else:
|
||||
d_model = 1024
|
||||
# text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
|
||||
# text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model: d_model * 2, :]
|
||||
# text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2:, :]
|
||||
|
||||
# text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
|
||||
# text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model: d_model * 2, :]
|
||||
# text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2:, :]
|
||||
# add diffusers operators
|
||||
diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight"] = {
|
||||
"slice": [
|
||||
f"{ldm_res_block_prefix}.{number}.attn.in_proj_weight",
|
||||
f"0:{d_model}, :"
|
||||
]
|
||||
}
|
||||
diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight"] = {
|
||||
"slice": [
|
||||
f"{ldm_res_block_prefix}.{number}.attn.in_proj_weight",
|
||||
f"{d_model}:{d_model * 2}, :"
|
||||
]
|
||||
}
|
||||
diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight"] = {
|
||||
"slice": [
|
||||
f"{ldm_res_block_prefix}.{number}.attn.in_proj_weight",
|
||||
f"{d_model * 2}:, :"
|
||||
]
|
||||
}
|
||||
|
||||
# add diffusers operators
|
||||
diffusers_operator_map[f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.weight"] = {
|
||||
"slice": [
|
||||
f"conditioner.embedders.1.model.transformer.resblocks.{number}.attn.in_proj_weight",
|
||||
f"0:{d_model}, :"
|
||||
]
|
||||
}
|
||||
diffusers_operator_map[f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.weight"] = {
|
||||
"slice": [
|
||||
f"conditioner.embedders.1.model.transformer.resblocks.{number}.attn.in_proj_weight",
|
||||
f"{d_model}:{d_model * 2}, :"
|
||||
]
|
||||
}
|
||||
diffusers_operator_map[f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.weight"] = {
|
||||
"slice": [
|
||||
f"conditioner.embedders.1.model.transformer.resblocks.{number}.attn.in_proj_weight",
|
||||
f"{d_model * 2}:, :"
|
||||
]
|
||||
}
|
||||
match = re.match(proj_pattern_bias, ldm_key)
|
||||
if match:
|
||||
number = int(match.group(1))
|
||||
new_val = torch.cat([
|
||||
diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias"],
|
||||
diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias"],
|
||||
diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias"],
|
||||
], dim=0)
|
||||
# add to matched so we don't check them
|
||||
matched_diffusers_keys.append(f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias")
|
||||
matched_diffusers_keys.append(f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias")
|
||||
matched_diffusers_keys.append(f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias")
|
||||
# make diffusers convertable_dict
|
||||
diffusers_state_dict[
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.bias"] = new_val
|
||||
|
||||
pattern = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
|
||||
match = re.match(pattern, ldm_key)
|
||||
if match:
|
||||
number = int(match.group(1))
|
||||
new_val = torch.cat([
|
||||
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.bias"],
|
||||
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.bias"],
|
||||
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.bias"],
|
||||
], dim=0)
|
||||
# add to matched so we don't check them
|
||||
matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.bias")
|
||||
matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.bias")
|
||||
matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.bias")
|
||||
# make diffusers convertable_dict
|
||||
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.MERGED.bias"] = new_val
|
||||
# add operator
|
||||
ldm_operator_map[ldm_key] = {
|
||||
"cat": [
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias",
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias",
|
||||
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias",
|
||||
],
|
||||
# "target": f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.bias"
|
||||
}
|
||||
|
||||
# add operator
|
||||
ldm_operator_map[ldm_key] = {
|
||||
"cat": [
|
||||
f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.bias",
|
||||
f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.bias",
|
||||
f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.bias",
|
||||
],
|
||||
"target": f"te1_text_model.encoder.layers.{number}.self_attn.MERGED.bias"
|
||||
}
|
||||
# add diffusers operators
|
||||
diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias"] = {
|
||||
"slice": [
|
||||
f"{ldm_res_block_prefix}.{number}.attn.in_proj_bias",
|
||||
f"0:{d_model}, :"
|
||||
]
|
||||
}
|
||||
diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias"] = {
|
||||
"slice": [
|
||||
f"{ldm_res_block_prefix}.{number}.attn.in_proj_bias",
|
||||
f"{d_model}:{d_model * 2}, :"
|
||||
]
|
||||
}
|
||||
diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias"] = {
|
||||
"slice": [
|
||||
f"{ldm_res_block_prefix}.{number}.attn.in_proj_bias",
|
||||
f"{d_model * 2}:, :"
|
||||
]
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Error on key {ldm_key}")
|
||||
print(e)
|
||||
|
||||
# update keys
|
||||
diffusers_dict_keys = list(diffusers_state_dict.keys())
|
||||
@@ -275,14 +324,10 @@ if has_unmatched_keys:
|
||||
weight = ldm_state_dict[key]
|
||||
weight_min = weight.min().item()
|
||||
weight_max = weight.max().item()
|
||||
weight_mean = weight.mean().item()
|
||||
weight_std = weight.std().item()
|
||||
unmatched_obj['ldm'][key] = {
|
||||
'shape': weight.shape,
|
||||
"min": weight_min,
|
||||
"max": weight_max,
|
||||
"mean": weight_mean,
|
||||
"std": weight_std,
|
||||
}
|
||||
del weight
|
||||
flush()
|
||||
@@ -292,14 +337,10 @@ if has_unmatched_keys:
|
||||
weight = diffusers_state_dict[key]
|
||||
weight_min = weight.min().item()
|
||||
weight_max = weight.max().item()
|
||||
weight_mean = weight.mean().item()
|
||||
weight_std = weight.std().item()
|
||||
unmatched_obj['diffusers'][key] = {
|
||||
"shape": weight.shape,
|
||||
"min": weight_min,
|
||||
"max": weight_max,
|
||||
"mean": weight_mean,
|
||||
"std": weight_std,
|
||||
}
|
||||
del weight
|
||||
flush()
|
||||
@@ -318,7 +359,6 @@ for key in unmatched_ldm_keys:
|
||||
save_file(remaining_ldm_values, os.path.join(KEYMAPS_FOLDER, f'{name}_ldm_base.safetensors'))
|
||||
print(f'Saved remaining ldm values to {os.path.join(KEYMAPS_FOLDER, f"{name}_ldm_base.safetensors")}')
|
||||
|
||||
|
||||
dest_path = os.path.join(KEYMAPS_FOLDER, f'{name}.json')
|
||||
save_obj = OrderedDict()
|
||||
save_obj["ldm_diffusers_keymap"] = ldm_diffusers_keymap
|
||||
|
||||
@@ -13,7 +13,7 @@ import json
|
||||
|
||||
from toolkit.config_modules import ModelConfig
|
||||
from toolkit.paths import KEYMAPS_ROOT
|
||||
from toolkit.saving import convert_state_dict_to_ldm_with_mapping
|
||||
from toolkit.saving import convert_state_dict_to_ldm_with_mapping, get_ldm_state_dict_from_diffusers
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
# this was just used to match the vae keys to the diffusers keys
|
||||
@@ -39,6 +39,12 @@ parser.add_argument(
|
||||
help='Is the model an XL model'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--is_v2',
|
||||
action='store_true',
|
||||
help='Is the model a v2 model'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
find_matches = False
|
||||
@@ -58,19 +64,20 @@ sd = StableDiffusion(
|
||||
)
|
||||
sd.load_model()
|
||||
|
||||
if not args.is_xl:
|
||||
# not supported yet
|
||||
raise NotImplementedError("Only SDXL is supported at this time with this method")
|
||||
# load our base
|
||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
|
||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')
|
||||
|
||||
print("Converting model back to LDM")
|
||||
version_string = '1'
|
||||
if args.is_v2:
|
||||
version_string = '2'
|
||||
if args.is_xl:
|
||||
version_string = 'sdxl'
|
||||
# convert the state dict
|
||||
state_dict_file_2 = convert_state_dict_to_ldm_with_mapping(
|
||||
state_dict_file_2 = get_ldm_state_dict_from_diffusers(
|
||||
sd.state_dict(),
|
||||
mapping_path,
|
||||
base_path,
|
||||
version_string,
|
||||
device='cpu',
|
||||
dtype=dtype
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user