mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-11 00:10:24 +00:00
Added converters for all stable diffusion models to convert back to ldm format from diffusers.
This commit is contained in:
@@ -2,6 +2,11 @@ import argparse
|
|||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import os
|
||||||
|
# add project root to sys path
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers.loaders import LoraLoaderMixin
|
from diffusers.loaders import LoraLoaderMixin
|
||||||
@@ -90,91 +95,135 @@ matched_diffusers_keys = []
|
|||||||
|
|
||||||
error_margin = 1e-4
|
error_margin = 1e-4
|
||||||
|
|
||||||
|
te_suffix = ''
|
||||||
|
proj_pattern_weight = None
|
||||||
|
proj_pattern_bias = None
|
||||||
|
text_proj_layer = None
|
||||||
if args.sdxl:
|
if args.sdxl:
|
||||||
# do pre known merging
|
te_suffix = '1'
|
||||||
for ldm_key in ldm_dict_keys:
|
ldm_res_block_prefix = "conditioner.embedders.1.model.transformer.resblocks"
|
||||||
pattern = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
|
proj_pattern_weight = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
|
||||||
match = re.match(pattern, ldm_key)
|
proj_pattern_bias = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
|
||||||
if match:
|
text_proj_layer = "conditioner.embedders.1.model.text_projection"
|
||||||
number = int(match.group(1))
|
if args.sd2:
|
||||||
new_val = torch.cat([
|
te_suffix = ''
|
||||||
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.weight"],
|
ldm_res_block_prefix = "cond_stage_model.model.transformer.resblocks"
|
||||||
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.weight"],
|
proj_pattern_weight = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
|
||||||
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.weight"],
|
proj_pattern_bias = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
|
||||||
], dim=0)
|
text_proj_layer = "cond_stage_model.model.text_projection"
|
||||||
# add to matched so we dont check them
|
|
||||||
matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.weight")
|
|
||||||
matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.weight")
|
|
||||||
matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.weight")
|
|
||||||
# make diffusers convertable_dict
|
|
||||||
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.MERGED.weight"] = new_val
|
|
||||||
|
|
||||||
# add operator
|
|
||||||
ldm_operator_map[ldm_key] = {
|
|
||||||
"cat": [
|
|
||||||
f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.weight",
|
|
||||||
f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.weight",
|
|
||||||
f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.weight",
|
|
||||||
],
|
|
||||||
"target": f"te1_text_model.encoder.layers.{number}.self_attn.MERGED.weight"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
if args.sdxl or args.sd2:
|
||||||
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
|
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
|
||||||
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
|
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
|
||||||
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
|
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
|
||||||
else:
|
else:
|
||||||
d_model = 1024
|
d_model = 1024
|
||||||
|
|
||||||
|
# do pre known merging
|
||||||
|
for ldm_key in ldm_dict_keys:
|
||||||
|
try:
|
||||||
|
match = re.match(proj_pattern_weight, ldm_key)
|
||||||
|
if match:
|
||||||
|
number = int(match.group(1))
|
||||||
|
new_val = torch.cat([
|
||||||
|
diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight"],
|
||||||
|
diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight"],
|
||||||
|
diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight"],
|
||||||
|
], dim=0)
|
||||||
|
# add to matched so we dont check them
|
||||||
|
matched_diffusers_keys.append(
|
||||||
|
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight")
|
||||||
|
matched_diffusers_keys.append(
|
||||||
|
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight")
|
||||||
|
matched_diffusers_keys.append(
|
||||||
|
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight")
|
||||||
|
# make diffusers convertable_dict
|
||||||
|
diffusers_state_dict[
|
||||||
|
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.weight"] = new_val
|
||||||
|
|
||||||
|
# add operator
|
||||||
|
ldm_operator_map[ldm_key] = {
|
||||||
|
"cat": [
|
||||||
|
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight",
|
||||||
|
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight",
|
||||||
|
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight",
|
||||||
|
],
|
||||||
|
"target": f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.weight"
|
||||||
|
}
|
||||||
|
|
||||||
# text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
|
# text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
|
||||||
# text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model: d_model * 2, :]
|
# text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model: d_model * 2, :]
|
||||||
# text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2:, :]
|
# text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2:, :]
|
||||||
|
|
||||||
# add diffusers operators
|
# add diffusers operators
|
||||||
diffusers_operator_map[f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.weight"] = {
|
diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight"] = {
|
||||||
"slice": [
|
"slice": [
|
||||||
f"conditioner.embedders.1.model.transformer.resblocks.{number}.attn.in_proj_weight",
|
f"{ldm_res_block_prefix}.{number}.attn.in_proj_weight",
|
||||||
f"0:{d_model}, :"
|
f"0:{d_model}, :"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
diffusers_operator_map[f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.weight"] = {
|
diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight"] = {
|
||||||
"slice": [
|
"slice": [
|
||||||
f"conditioner.embedders.1.model.transformer.resblocks.{number}.attn.in_proj_weight",
|
f"{ldm_res_block_prefix}.{number}.attn.in_proj_weight",
|
||||||
f"{d_model}:{d_model * 2}, :"
|
f"{d_model}:{d_model * 2}, :"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
diffusers_operator_map[f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.weight"] = {
|
diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight"] = {
|
||||||
"slice": [
|
"slice": [
|
||||||
f"conditioner.embedders.1.model.transformer.resblocks.{number}.attn.in_proj_weight",
|
f"{ldm_res_block_prefix}.{number}.attn.in_proj_weight",
|
||||||
f"{d_model * 2}:, :"
|
f"{d_model * 2}:, :"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
pattern = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
|
match = re.match(proj_pattern_bias, ldm_key)
|
||||||
match = re.match(pattern, ldm_key)
|
|
||||||
if match:
|
if match:
|
||||||
number = int(match.group(1))
|
number = int(match.group(1))
|
||||||
new_val = torch.cat([
|
new_val = torch.cat([
|
||||||
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.bias"],
|
diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias"],
|
||||||
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.bias"],
|
diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias"],
|
||||||
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.bias"],
|
diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias"],
|
||||||
], dim=0)
|
], dim=0)
|
||||||
# add to matched so we dont check them
|
# add to matched so we dont check them
|
||||||
matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.bias")
|
matched_diffusers_keys.append(f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias")
|
||||||
matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.bias")
|
matched_diffusers_keys.append(f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias")
|
||||||
matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.bias")
|
matched_diffusers_keys.append(f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias")
|
||||||
# make diffusers convertable_dict
|
# make diffusers convertable_dict
|
||||||
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.MERGED.bias"] = new_val
|
diffusers_state_dict[
|
||||||
|
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.bias"] = new_val
|
||||||
|
|
||||||
# add operator
|
# add operator
|
||||||
ldm_operator_map[ldm_key] = {
|
ldm_operator_map[ldm_key] = {
|
||||||
"cat": [
|
"cat": [
|
||||||
f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.bias",
|
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias",
|
||||||
f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.bias",
|
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias",
|
||||||
f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.bias",
|
f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias",
|
||||||
],
|
],
|
||||||
"target": f"te1_text_model.encoder.layers.{number}.self_attn.MERGED.bias"
|
# "target": f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.MERGED.bias"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# add diffusers operators
|
||||||
|
diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias"] = {
|
||||||
|
"slice": [
|
||||||
|
f"{ldm_res_block_prefix}.{number}.attn.in_proj_bias",
|
||||||
|
f"0:{d_model}, :"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias"] = {
|
||||||
|
"slice": [
|
||||||
|
f"{ldm_res_block_prefix}.{number}.attn.in_proj_bias",
|
||||||
|
f"{d_model}:{d_model * 2}, :"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias"] = {
|
||||||
|
"slice": [
|
||||||
|
f"{ldm_res_block_prefix}.{number}.attn.in_proj_bias",
|
||||||
|
f"{d_model * 2}:, :"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error on key {ldm_key}")
|
||||||
|
print(e)
|
||||||
|
|
||||||
# update keys
|
# update keys
|
||||||
diffusers_dict_keys = list(diffusers_state_dict.keys())
|
diffusers_dict_keys = list(diffusers_state_dict.keys())
|
||||||
|
|
||||||
@@ -275,14 +324,10 @@ if has_unmatched_keys:
|
|||||||
weight = ldm_state_dict[key]
|
weight = ldm_state_dict[key]
|
||||||
weight_min = weight.min().item()
|
weight_min = weight.min().item()
|
||||||
weight_max = weight.max().item()
|
weight_max = weight.max().item()
|
||||||
weight_mean = weight.mean().item()
|
|
||||||
weight_std = weight.std().item()
|
|
||||||
unmatched_obj['ldm'][key] = {
|
unmatched_obj['ldm'][key] = {
|
||||||
'shape': weight.shape,
|
'shape': weight.shape,
|
||||||
"min": weight_min,
|
"min": weight_min,
|
||||||
"max": weight_max,
|
"max": weight_max,
|
||||||
"mean": weight_mean,
|
|
||||||
"std": weight_std,
|
|
||||||
}
|
}
|
||||||
del weight
|
del weight
|
||||||
flush()
|
flush()
|
||||||
@@ -292,14 +337,10 @@ if has_unmatched_keys:
|
|||||||
weight = diffusers_state_dict[key]
|
weight = diffusers_state_dict[key]
|
||||||
weight_min = weight.min().item()
|
weight_min = weight.min().item()
|
||||||
weight_max = weight.max().item()
|
weight_max = weight.max().item()
|
||||||
weight_mean = weight.mean().item()
|
|
||||||
weight_std = weight.std().item()
|
|
||||||
unmatched_obj['diffusers'][key] = {
|
unmatched_obj['diffusers'][key] = {
|
||||||
"shape": weight.shape,
|
"shape": weight.shape,
|
||||||
"min": weight_min,
|
"min": weight_min,
|
||||||
"max": weight_max,
|
"max": weight_max,
|
||||||
"mean": weight_mean,
|
|
||||||
"std": weight_std,
|
|
||||||
}
|
}
|
||||||
del weight
|
del weight
|
||||||
flush()
|
flush()
|
||||||
@@ -318,7 +359,6 @@ for key in unmatched_ldm_keys:
|
|||||||
save_file(remaining_ldm_values, os.path.join(KEYMAPS_FOLDER, f'{name}_ldm_base.safetensors'))
|
save_file(remaining_ldm_values, os.path.join(KEYMAPS_FOLDER, f'{name}_ldm_base.safetensors'))
|
||||||
print(f'Saved remaining ldm values to {os.path.join(KEYMAPS_FOLDER, f"{name}_ldm_base.safetensors")}')
|
print(f'Saved remaining ldm values to {os.path.join(KEYMAPS_FOLDER, f"{name}_ldm_base.safetensors")}')
|
||||||
|
|
||||||
|
|
||||||
dest_path = os.path.join(KEYMAPS_FOLDER, f'{name}.json')
|
dest_path = os.path.join(KEYMAPS_FOLDER, f'{name}.json')
|
||||||
save_obj = OrderedDict()
|
save_obj = OrderedDict()
|
||||||
save_obj["ldm_diffusers_keymap"] = ldm_diffusers_keymap
|
save_obj["ldm_diffusers_keymap"] = ldm_diffusers_keymap
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import json
|
|||||||
|
|
||||||
from toolkit.config_modules import ModelConfig
|
from toolkit.config_modules import ModelConfig
|
||||||
from toolkit.paths import KEYMAPS_ROOT
|
from toolkit.paths import KEYMAPS_ROOT
|
||||||
from toolkit.saving import convert_state_dict_to_ldm_with_mapping
|
from toolkit.saving import convert_state_dict_to_ldm_with_mapping, get_ldm_state_dict_from_diffusers
|
||||||
from toolkit.stable_diffusion_model import StableDiffusion
|
from toolkit.stable_diffusion_model import StableDiffusion
|
||||||
|
|
||||||
# this was just used to match the vae keys to the diffusers keys
|
# this was just used to match the vae keys to the diffusers keys
|
||||||
@@ -39,6 +39,12 @@ parser.add_argument(
|
|||||||
help='Is the model an XL model'
|
help='Is the model an XL model'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--is_v2',
|
||||||
|
action='store_true',
|
||||||
|
help='Is the model a v2 model'
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
find_matches = False
|
find_matches = False
|
||||||
@@ -58,19 +64,20 @@ sd = StableDiffusion(
|
|||||||
)
|
)
|
||||||
sd.load_model()
|
sd.load_model()
|
||||||
|
|
||||||
if not args.is_xl:
|
|
||||||
# not supported yet
|
|
||||||
raise NotImplementedError("Only SDXL is supported at this time with this method")
|
|
||||||
# load our base
|
# load our base
|
||||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
|
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
|
||||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')
|
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')
|
||||||
|
|
||||||
print("Converting model back to LDM")
|
print("Converting model back to LDM")
|
||||||
|
version_string = '1'
|
||||||
|
if args.is_v2:
|
||||||
|
version_string = '2'
|
||||||
|
if args.is_xl:
|
||||||
|
version_string = 'sdxl'
|
||||||
# convert the state dict
|
# convert the state dict
|
||||||
state_dict_file_2 = convert_state_dict_to_ldm_with_mapping(
|
state_dict_file_2 = get_ldm_state_dict_from_diffusers(
|
||||||
sd.state_dict(),
|
sd.state_dict(),
|
||||||
mapping_path,
|
version_string,
|
||||||
base_path,
|
|
||||||
device='cpu',
|
device='cpu',
|
||||||
dtype=dtype
|
dtype=dtype
|
||||||
)
|
)
|
||||||
|
|||||||
1962
toolkit/keymaps/stable_diffusion_sd1.json
Normal file
1962
toolkit/keymaps/stable_diffusion_sd1.json
Normal file
File diff suppressed because it is too large
Load Diff
BIN
toolkit/keymaps/stable_diffusion_sd1_ldm_base.safetensors
Normal file
BIN
toolkit/keymaps/stable_diffusion_sd1_ldm_base.safetensors
Normal file
Binary file not shown.
2773
toolkit/keymaps/stable_diffusion_sd2.json
Normal file
2773
toolkit/keymaps/stable_diffusion_sd2.json
Normal file
File diff suppressed because it is too large
Load Diff
BIN
toolkit/keymaps/stable_diffusion_sd2_ldm_base.safetensors
Normal file
BIN
toolkit/keymaps/stable_diffusion_sd2_ldm_base.safetensors
Normal file
Binary file not shown.
200
toolkit/keymaps/stable_diffusion_sd2_unmatched.json
Normal file
200
toolkit/keymaps/stable_diffusion_sd2_unmatched.json
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
{
|
||||||
|
"ldm": {
|
||||||
|
"alphas_cumprod": {
|
||||||
|
"shape": [
|
||||||
|
1000
|
||||||
|
],
|
||||||
|
"min": 0.00466156005859375,
|
||||||
|
"max": 0.9990234375
|
||||||
|
},
|
||||||
|
"alphas_cumprod_prev": {
|
||||||
|
"shape": [
|
||||||
|
1000
|
||||||
|
],
|
||||||
|
"min": 0.0047149658203125,
|
||||||
|
"max": 1.0
|
||||||
|
},
|
||||||
|
"betas": {
|
||||||
|
"shape": [
|
||||||
|
1000
|
||||||
|
],
|
||||||
|
"min": 0.0008502006530761719,
|
||||||
|
"max": 0.01200103759765625
|
||||||
|
},
|
||||||
|
"cond_stage_model.model.logit_scale": {
|
||||||
|
"shape": [],
|
||||||
|
"min": 4.60546875,
|
||||||
|
"max": 4.60546875
|
||||||
|
},
|
||||||
|
"cond_stage_model.model.text_projection": {
|
||||||
|
"shape": [
|
||||||
|
1024,
|
||||||
|
1024
|
||||||
|
],
|
||||||
|
"min": -0.109130859375,
|
||||||
|
"max": 0.09271240234375
|
||||||
|
},
|
||||||
|
"cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias": {
|
||||||
|
"shape": [
|
||||||
|
3072
|
||||||
|
],
|
||||||
|
"min": -2.525390625,
|
||||||
|
"max": 2.591796875
|
||||||
|
},
|
||||||
|
"cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight": {
|
||||||
|
"shape": [
|
||||||
|
3072,
|
||||||
|
1024
|
||||||
|
],
|
||||||
|
"min": -0.12261962890625,
|
||||||
|
"max": 0.1258544921875
|
||||||
|
},
|
||||||
|
"cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias": {
|
||||||
|
"shape": [
|
||||||
|
1024
|
||||||
|
],
|
||||||
|
"min": -0.422607421875,
|
||||||
|
"max": 1.17578125
|
||||||
|
},
|
||||||
|
"cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight": {
|
||||||
|
"shape": [
|
||||||
|
1024,
|
||||||
|
1024
|
||||||
|
],
|
||||||
|
"min": -0.0738525390625,
|
||||||
|
"max": 0.08673095703125
|
||||||
|
},
|
||||||
|
"cond_stage_model.model.transformer.resblocks.23.ln_1.bias": {
|
||||||
|
"shape": [
|
||||||
|
1024
|
||||||
|
],
|
||||||
|
"min": -3.392578125,
|
||||||
|
"max": 0.90625
|
||||||
|
},
|
||||||
|
"cond_stage_model.model.transformer.resblocks.23.ln_1.weight": {
|
||||||
|
"shape": [
|
||||||
|
1024
|
||||||
|
],
|
||||||
|
"min": 0.379638671875,
|
||||||
|
"max": 2.02734375
|
||||||
|
},
|
||||||
|
"cond_stage_model.model.transformer.resblocks.23.ln_2.bias": {
|
||||||
|
"shape": [
|
||||||
|
1024
|
||||||
|
],
|
||||||
|
"min": -0.833984375,
|
||||||
|
"max": 2.525390625
|
||||||
|
},
|
||||||
|
"cond_stage_model.model.transformer.resblocks.23.ln_2.weight": {
|
||||||
|
"shape": [
|
||||||
|
1024
|
||||||
|
],
|
||||||
|
"min": 1.17578125,
|
||||||
|
"max": 2.037109375
|
||||||
|
},
|
||||||
|
"cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias": {
|
||||||
|
"shape": [
|
||||||
|
4096
|
||||||
|
],
|
||||||
|
"min": -1.619140625,
|
||||||
|
"max": 0.5595703125
|
||||||
|
},
|
||||||
|
"cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight": {
|
||||||
|
"shape": [
|
||||||
|
4096,
|
||||||
|
1024
|
||||||
|
],
|
||||||
|
"min": -0.08953857421875,
|
||||||
|
"max": 0.13232421875
|
||||||
|
},
|
||||||
|
"cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias": {
|
||||||
|
"shape": [
|
||||||
|
1024
|
||||||
|
],
|
||||||
|
"min": -1.8662109375,
|
||||||
|
"max": 0.74658203125
|
||||||
|
},
|
||||||
|
"cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight": {
|
||||||
|
"shape": [
|
||||||
|
1024,
|
||||||
|
4096
|
||||||
|
],
|
||||||
|
"min": -0.12939453125,
|
||||||
|
"max": 0.1009521484375
|
||||||
|
},
|
||||||
|
"log_one_minus_alphas_cumprod": {
|
||||||
|
"shape": [
|
||||||
|
1000
|
||||||
|
],
|
||||||
|
"min": -7.0703125,
|
||||||
|
"max": -0.004669189453125
|
||||||
|
},
|
||||||
|
"model_ema.decay": {
|
||||||
|
"shape": [],
|
||||||
|
"min": 1.0,
|
||||||
|
"max": 1.0
|
||||||
|
},
|
||||||
|
"model_ema.num_updates": {
|
||||||
|
"shape": [],
|
||||||
|
"min": 219996,
|
||||||
|
"max": 219996
|
||||||
|
},
|
||||||
|
"posterior_log_variance_clipped": {
|
||||||
|
"shape": [
|
||||||
|
1000
|
||||||
|
],
|
||||||
|
"min": -46.0625,
|
||||||
|
"max": -4.421875
|
||||||
|
},
|
||||||
|
"posterior_mean_coef1": {
|
||||||
|
"shape": [
|
||||||
|
1000
|
||||||
|
],
|
||||||
|
"min": 0.000827789306640625,
|
||||||
|
"max": 1.0
|
||||||
|
},
|
||||||
|
"posterior_mean_coef2": {
|
||||||
|
"shape": [
|
||||||
|
1000
|
||||||
|
],
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 0.99560546875
|
||||||
|
},
|
||||||
|
"posterior_variance": {
|
||||||
|
"shape": [
|
||||||
|
1000
|
||||||
|
],
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 0.01200103759765625
|
||||||
|
},
|
||||||
|
"sqrt_alphas_cumprod": {
|
||||||
|
"shape": [
|
||||||
|
1000
|
||||||
|
],
|
||||||
|
"min": 0.0682373046875,
|
||||||
|
"max": 0.99951171875
|
||||||
|
},
|
||||||
|
"sqrt_one_minus_alphas_cumprod": {
|
||||||
|
"shape": [
|
||||||
|
1000
|
||||||
|
],
|
||||||
|
"min": 0.0291595458984375,
|
||||||
|
"max": 0.99755859375
|
||||||
|
},
|
||||||
|
"sqrt_recip_alphas_cumprod": {
|
||||||
|
"shape": [
|
||||||
|
1000
|
||||||
|
],
|
||||||
|
"min": 1.0,
|
||||||
|
"max": 14.6484375
|
||||||
|
},
|
||||||
|
"sqrt_recipm1_alphas_cumprod": {
|
||||||
|
"shape": [
|
||||||
|
1000
|
||||||
|
],
|
||||||
|
"min": 0.0291595458984375,
|
||||||
|
"max": 14.6171875
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"diffusers": {}
|
||||||
|
}
|
||||||
@@ -71,6 +71,35 @@ def convert_state_dict_to_ldm_with_mapping(
|
|||||||
return converted_state_dict
|
return converted_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def get_ldm_state_dict_from_diffusers(
        state_dict: 'OrderedDict',
        sd_version: Literal['1', '2', 'sdxl'] = '2',
        device='cpu',
        dtype=get_torch_dtype('fp32'),
):
    """Convert a diffusers state dict to LDM format for a given SD version.

    Selects the LDM base weights file and the key-mapping JSON that belong
    to ``sd_version`` (both joined onto ``KEYMAPS_ROOT``) and hands them,
    together with ``state_dict``, ``device`` and ``dtype``, to
    ``convert_state_dict_to_ldm_with_mapping``.

    Args:
        state_dict: the diffusers-format state dict to convert.
        sd_version: one of '1', '2' or 'sdxl'.
        device: forwarded unchanged to the converter.
        dtype: forwarded unchanged to the converter.

    Returns:
        Whatever ``convert_state_dict_to_ldm_with_mapping`` returns.

    Raises:
        ValueError: if ``sd_version`` is not one of the supported values.
    """
    # (version, base weights file name, key mapping file name)
    keymap_files = (
        ('1', 'stable_diffusion_sd1_ldm_base.safetensors', 'stable_diffusion_sd1.json'),
        ('2', 'stable_diffusion_sd2_ldm_base.safetensors', 'stable_diffusion_sd2.json'),
        ('sdxl', 'stable_diffusion_sdxl_ldm_base.safetensors', 'stable_diffusion_sdxl.json'),
    )

    for known_version, base_name, mapping_name in keymap_files:
        if sd_version == known_version:
            break
    else:
        # no table entry matched the requested version
        raise ValueError(f"Invalid sd_version {sd_version}")

    base_path = os.path.join(KEYMAPS_ROOT, base_name)
    mapping_path = os.path.join(KEYMAPS_ROOT, mapping_name)

    # convert the state dict
    return convert_state_dict_to_ldm_with_mapping(
        state_dict,
        mapping_path,
        base_path,
        device=device,
        dtype=dtype,
    )
|
||||||
|
|
||||||
|
|
||||||
def save_ldm_model_from_diffusers(
|
def save_ldm_model_from_diffusers(
|
||||||
sd: 'StableDiffusion',
|
sd: 'StableDiffusion',
|
||||||
output_file: str,
|
output_file: str,
|
||||||
@@ -78,21 +107,13 @@ def save_ldm_model_from_diffusers(
|
|||||||
save_dtype=get_torch_dtype('fp16'),
|
save_dtype=get_torch_dtype('fp16'),
|
||||||
sd_version: Literal['1', '2', 'sdxl'] = '2'
|
sd_version: Literal['1', '2', 'sdxl'] = '2'
|
||||||
):
|
):
|
||||||
if sd_version != 'sdxl':
|
converted_state_dict = get_ldm_state_dict_from_diffusers(
|
||||||
# not supported yet
|
|
||||||
raise NotImplementedError("Only SDXL is supported at this time with this method")
|
|
||||||
# load our base
|
|
||||||
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
|
|
||||||
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')
|
|
||||||
|
|
||||||
# convert the state dict
|
|
||||||
converted_state_dict = convert_state_dict_to_ldm_with_mapping(
|
|
||||||
sd.state_dict(),
|
sd.state_dict(),
|
||||||
mapping_path,
|
sd_version,
|
||||||
base_path,
|
|
||||||
device='cpu',
|
device='cpu',
|
||||||
dtype=save_dtype
|
dtype=save_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
# make sure parent folder exists
|
# make sure parent folder exists
|
||||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||||
save_file(converted_state_dict, output_file, metadata=meta)
|
save_file(converted_state_dict, output_file, metadata=meta)
|
||||||
|
|||||||
@@ -752,61 +752,28 @@ class StableDiffusion:
|
|||||||
for i, encoder in enumerate(self.text_encoder):
|
for i, encoder in enumerate(self.text_encoder):
|
||||||
for k, v in encoder.state_dict().items():
|
for k, v in encoder.state_dict().items():
|
||||||
new_key = k if k.startswith(
|
new_key = k if k.startswith(
|
||||||
f"{SD_PREFIX_TEXT_ENCODER}{i}") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}"
|
f"{SD_PREFIX_TEXT_ENCODER}{i}_") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}"
|
||||||
state_dict[new_key] = v
|
state_dict[new_key] = v
|
||||||
else:
|
else:
|
||||||
for k, v in self.text_encoder.state_dict().items():
|
for k, v in self.text_encoder.state_dict().items():
|
||||||
new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER}") else f"{SD_PREFIX_TEXT_ENCODER}_{k}"
|
new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER}_") else f"{SD_PREFIX_TEXT_ENCODER}_{k}"
|
||||||
state_dict[new_key] = v
|
state_dict[new_key] = v
|
||||||
if unet:
|
if unet:
|
||||||
for k, v in self.unet.state_dict().items():
|
for k, v in self.unet.state_dict().items():
|
||||||
new_key = k if k.startswith(f"{SD_PREFIX_UNET}") else f"{SD_PREFIX_UNET}_{k}"
|
new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
|
||||||
state_dict[new_key] = v
|
state_dict[new_key] = v
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
|
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
|
||||||
state_dict = {}
|
version_string = '1'
|
||||||
# prepare metadata
|
if self.is_v2:
|
||||||
meta = get_meta_for_safetensors(meta)
|
version_string = '2'
|
||||||
|
|
||||||
def update_sd(prefix, sd):
|
|
||||||
for k, v in sd.items():
|
|
||||||
key = prefix + k
|
|
||||||
v = v.detach().clone()
|
|
||||||
state_dict[key] = v.to("cpu", dtype=get_torch_dtype(save_dtype))
|
|
||||||
# make sure there are not nan values
|
|
||||||
if torch.isnan(state_dict[key]).any():
|
|
||||||
raise ValueError(f"NaN value in state dict: {key}")
|
|
||||||
|
|
||||||
# todo see what logit scale is
|
|
||||||
if self.is_xl:
|
if self.is_xl:
|
||||||
|
version_string = 'sdxl'
|
||||||
save_ldm_model_from_diffusers(
|
save_ldm_model_from_diffusers(
|
||||||
sd=self,
|
sd=self,
|
||||||
output_file=output_file,
|
output_file=output_file,
|
||||||
meta=meta,
|
meta=meta,
|
||||||
save_dtype=save_dtype,
|
save_dtype=save_dtype,
|
||||||
sd_version='sdxl',
|
sd_version=version_string,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
|
||||||
# Convert the UNet model
|
|
||||||
unet_state_dict = convert_unet_state_dict_to_sd(self.is_v2, self.unet.state_dict())
|
|
||||||
update_sd("model.diffusion_model.", unet_state_dict)
|
|
||||||
|
|
||||||
# Convert the text encoder model
|
|
||||||
if self.is_v2:
|
|
||||||
make_dummy = True
|
|
||||||
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(self.text_encoder.state_dict(), make_dummy)
|
|
||||||
update_sd("cond_stage_model.model.", text_enc_dict)
|
|
||||||
else:
|
|
||||||
text_enc_dict = self.text_encoder.state_dict()
|
|
||||||
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
|
||||||
|
|
||||||
# Convert the VAE
|
|
||||||
if self.vae is not None:
|
|
||||||
vae_dict = model_util.convert_vae_state_dict(self.vae.state_dict())
|
|
||||||
update_sd("first_stage_model.", vae_dict)
|
|
||||||
|
|
||||||
# make sure parent folder exists
|
|
||||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
|
||||||
save_file(state_dict, output_file, metadata=meta)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user