mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added a converter back to ldm from diffusers for sdxl. Can finally get to training it properly
This commit is contained in:
@@ -95,7 +95,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
|||||||
if weight_jitter > 0.0:
|
if weight_jitter > 0.0:
|
||||||
jitter_list = random.uniform(-weight_jitter, weight_jitter)
|
jitter_list = random.uniform(-weight_jitter, weight_jitter)
|
||||||
network_pos_weight += jitter_list
|
network_pos_weight += jitter_list
|
||||||
network_neg_weight += jitter_list
|
network_neg_weight += (jitter_list * -1.0)
|
||||||
|
|
||||||
# if items in network_weight list are tensors, convert them to floats
|
# if items in network_weight list are tensors, convert them to floats
|
||||||
|
|
||||||
|
|||||||
@@ -248,7 +248,7 @@ class UltimateSliderTrainerProcess(BaseSDTrainProcess):
|
|||||||
if weight_jitter > 0.0:
|
if weight_jitter > 0.0:
|
||||||
jitter_list = random.uniform(-weight_jitter, weight_jitter)
|
jitter_list = random.uniform(-weight_jitter, weight_jitter)
|
||||||
network_pos_weight += jitter_list
|
network_pos_weight += jitter_list
|
||||||
network_neg_weight += jitter_list
|
network_neg_weight += (jitter_list * -1.0)
|
||||||
|
|
||||||
# if items in network_weight list are tensors, convert them to floats
|
# if items in network_weight list are tensors, convert them to floats
|
||||||
imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)
|
imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)
|
||||||
|
|||||||
332
testing/generate_weight_mappings.py
Normal file
332
testing/generate_weight_mappings.py
Normal file
@@ -0,0 +1,332 @@
|
|||||||
|
import argparse
|
||||||
|
import gc
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from diffusers.loaders import LoraLoaderMixin
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
|
from collections import OrderedDict
|
||||||
|
import json
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from toolkit.config_modules import ModelConfig
|
||||||
|
from toolkit.stable_diffusion_model import StableDiffusion
|
||||||
|
|
||||||
|
KEYMAPS_FOLDER = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'toolkit', 'keymaps')
|
||||||
|
|
||||||
|
device = torch.device('cpu')
|
||||||
|
dtype = torch.float32
|
||||||
|
|
||||||
|
|
||||||
|
def flush():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
def get_reduced_shape(shape_tuple):
|
||||||
|
# iterate though shape anr remove 1s
|
||||||
|
new_shape = []
|
||||||
|
for dim in shape_tuple:
|
||||||
|
if dim != 1:
|
||||||
|
new_shape.append(dim)
|
||||||
|
return tuple(new_shape)
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
# require at lease one config file
|
||||||
|
parser.add_argument(
|
||||||
|
'file_1',
|
||||||
|
nargs='+',
|
||||||
|
type=str,
|
||||||
|
help='Path to first safe tensor file'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument('--name', type=str, default='stable_diffusion', help='name for mapping to make')
|
||||||
|
parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
|
||||||
|
parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
file_path = args.file_1[0]
|
||||||
|
|
||||||
|
find_matches = False
|
||||||
|
|
||||||
|
print(f'Loading diffusers model')
|
||||||
|
|
||||||
|
diffusers_model_config = ModelConfig(
|
||||||
|
name_or_path=file_path,
|
||||||
|
is_xl=args.sdxl,
|
||||||
|
is_v2=args.sd2,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
diffusers_sd = StableDiffusion(
|
||||||
|
model_config=diffusers_model_config,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
diffusers_sd.load_model()
|
||||||
|
# delete things we dont need
|
||||||
|
del diffusers_sd.tokenizer
|
||||||
|
flush()
|
||||||
|
|
||||||
|
print(f'Loading ldm model')
|
||||||
|
diffusers_state_dict = diffusers_sd.state_dict()
|
||||||
|
diffusers_dict_keys = list(diffusers_state_dict.keys())
|
||||||
|
|
||||||
|
ldm_state_dict = load_file(file_path)
|
||||||
|
ldm_dict_keys = list(ldm_state_dict.keys())
|
||||||
|
|
||||||
|
ldm_diffusers_keymap = OrderedDict()
|
||||||
|
ldm_diffusers_shape_map = OrderedDict()
|
||||||
|
ldm_operator_map = OrderedDict()
|
||||||
|
diffusers_operator_map = OrderedDict()
|
||||||
|
|
||||||
|
total_keys = len(ldm_dict_keys)
|
||||||
|
|
||||||
|
matched_ldm_keys = []
|
||||||
|
matched_diffusers_keys = []
|
||||||
|
|
||||||
|
error_margin = 1e-4
|
||||||
|
|
||||||
|
if args.sdxl:
|
||||||
|
# do pre known merging
|
||||||
|
for ldm_key in ldm_dict_keys:
|
||||||
|
pattern = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
|
||||||
|
match = re.match(pattern, ldm_key)
|
||||||
|
if match:
|
||||||
|
number = int(match.group(1))
|
||||||
|
new_val = torch.cat([
|
||||||
|
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.weight"],
|
||||||
|
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.weight"],
|
||||||
|
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.weight"],
|
||||||
|
], dim=0)
|
||||||
|
# add to matched so we dont check them
|
||||||
|
matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.weight")
|
||||||
|
matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.weight")
|
||||||
|
matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.weight")
|
||||||
|
# make diffusers convertable_dict
|
||||||
|
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.MERGED.weight"] = new_val
|
||||||
|
|
||||||
|
# add operator
|
||||||
|
ldm_operator_map[ldm_key] = {
|
||||||
|
"cat": [
|
||||||
|
f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.weight",
|
||||||
|
f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.weight",
|
||||||
|
f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.weight",
|
||||||
|
],
|
||||||
|
"target": f"te1_text_model.encoder.layers.{number}.self_attn.MERGED.weight"
|
||||||
|
}
|
||||||
|
|
||||||
|
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
|
||||||
|
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
|
||||||
|
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
|
||||||
|
else:
|
||||||
|
d_model = 1024
|
||||||
|
|
||||||
|
# text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
|
||||||
|
# text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model: d_model * 2, :]
|
||||||
|
# text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2:, :]
|
||||||
|
|
||||||
|
# add diffusers operators
|
||||||
|
diffusers_operator_map[f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.weight"] = {
|
||||||
|
"slice": [
|
||||||
|
f"conditioner.embedders.1.model.transformer.resblocks.{number}.attn.in_proj_weight",
|
||||||
|
f"0:{d_model}, :"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
diffusers_operator_map[f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.weight"] = {
|
||||||
|
"slice": [
|
||||||
|
f"conditioner.embedders.1.model.transformer.resblocks.{number}.attn.in_proj_weight",
|
||||||
|
f"{d_model}:{d_model * 2}, :"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
diffusers_operator_map[f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.weight"] = {
|
||||||
|
"slice": [
|
||||||
|
f"conditioner.embedders.1.model.transformer.resblocks.{number}.attn.in_proj_weight",
|
||||||
|
f"{d_model * 2}:, :"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
pattern = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
|
||||||
|
match = re.match(pattern, ldm_key)
|
||||||
|
if match:
|
||||||
|
number = int(match.group(1))
|
||||||
|
new_val = torch.cat([
|
||||||
|
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.bias"],
|
||||||
|
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.bias"],
|
||||||
|
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.bias"],
|
||||||
|
], dim=0)
|
||||||
|
# add to matched so we dont check them
|
||||||
|
matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.bias")
|
||||||
|
matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.bias")
|
||||||
|
matched_diffusers_keys.append(f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.bias")
|
||||||
|
# make diffusers convertable_dict
|
||||||
|
diffusers_state_dict[f"te1_text_model.encoder.layers.{number}.self_attn.MERGED.bias"] = new_val
|
||||||
|
|
||||||
|
# add operator
|
||||||
|
ldm_operator_map[ldm_key] = {
|
||||||
|
"cat": [
|
||||||
|
f"te1_text_model.encoder.layers.{number}.self_attn.q_proj.bias",
|
||||||
|
f"te1_text_model.encoder.layers.{number}.self_attn.k_proj.bias",
|
||||||
|
f"te1_text_model.encoder.layers.{number}.self_attn.v_proj.bias",
|
||||||
|
],
|
||||||
|
"target": f"te1_text_model.encoder.layers.{number}.self_attn.MERGED.bias"
|
||||||
|
}
|
||||||
|
|
||||||
|
# update keys
|
||||||
|
diffusers_dict_keys = list(diffusers_state_dict.keys())
|
||||||
|
|
||||||
|
pbar = tqdm(ldm_dict_keys, desc='Matching ldm-diffusers keys', total=total_keys)
|
||||||
|
# run through all weights and check mse between them to find matches
|
||||||
|
for ldm_key in ldm_dict_keys:
|
||||||
|
ldm_shape_tuple = ldm_state_dict[ldm_key].shape
|
||||||
|
ldm_reduced_shape_tuple = get_reduced_shape(ldm_shape_tuple)
|
||||||
|
for diffusers_key in diffusers_dict_keys:
|
||||||
|
diffusers_shape_tuple = diffusers_state_dict[diffusers_key].shape
|
||||||
|
diffusers_reduced_shape_tuple = get_reduced_shape(diffusers_shape_tuple)
|
||||||
|
|
||||||
|
# That was easy. Same key
|
||||||
|
if ldm_key == diffusers_key:
|
||||||
|
ldm_diffusers_keymap[ldm_key] = diffusers_key
|
||||||
|
matched_ldm_keys.append(ldm_key)
|
||||||
|
matched_diffusers_keys.append(diffusers_key)
|
||||||
|
break
|
||||||
|
|
||||||
|
# if we already have this key mapped, skip it
|
||||||
|
if diffusers_key in matched_diffusers_keys:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# if reduced shapes do not match skip it
|
||||||
|
if ldm_reduced_shape_tuple != diffusers_reduced_shape_tuple:
|
||||||
|
continue
|
||||||
|
|
||||||
|
ldm_weight = ldm_state_dict[ldm_key]
|
||||||
|
did_reduce_ldm = False
|
||||||
|
diffusers_weight = diffusers_state_dict[diffusers_key]
|
||||||
|
did_reduce_diffusers = False
|
||||||
|
|
||||||
|
# reduce the shapes to match if they are not the same
|
||||||
|
if ldm_shape_tuple != ldm_reduced_shape_tuple:
|
||||||
|
ldm_weight = ldm_weight.view(ldm_reduced_shape_tuple)
|
||||||
|
did_reduce_ldm = True
|
||||||
|
|
||||||
|
if diffusers_shape_tuple != diffusers_reduced_shape_tuple:
|
||||||
|
diffusers_weight = diffusers_weight.view(diffusers_reduced_shape_tuple)
|
||||||
|
did_reduce_diffusers = True
|
||||||
|
|
||||||
|
# check to see if they match within a margin of error
|
||||||
|
mse = torch.nn.functional.mse_loss(ldm_weight, diffusers_weight)
|
||||||
|
if mse < error_margin:
|
||||||
|
ldm_diffusers_keymap[ldm_key] = diffusers_key
|
||||||
|
matched_ldm_keys.append(ldm_key)
|
||||||
|
matched_diffusers_keys.append(diffusers_key)
|
||||||
|
|
||||||
|
if did_reduce_ldm or did_reduce_diffusers:
|
||||||
|
ldm_diffusers_shape_map[ldm_key] = (ldm_shape_tuple, diffusers_shape_tuple)
|
||||||
|
if did_reduce_ldm:
|
||||||
|
del ldm_weight
|
||||||
|
if did_reduce_diffusers:
|
||||||
|
del diffusers_weight
|
||||||
|
flush()
|
||||||
|
|
||||||
|
break
|
||||||
|
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
pbar.close()
|
||||||
|
|
||||||
|
name = args.name
|
||||||
|
if args.sdxl:
|
||||||
|
name += '_sdxl'
|
||||||
|
elif args.sd2:
|
||||||
|
name += '_sd2'
|
||||||
|
else:
|
||||||
|
name += '_sd1'
|
||||||
|
|
||||||
|
# if len(matched_ldm_keys) != len(matched_diffusers_keys):
|
||||||
|
unmatched_ldm_keys = [x for x in ldm_dict_keys if x not in matched_ldm_keys]
|
||||||
|
unmatched_diffusers_keys = [x for x in diffusers_dict_keys if x not in matched_diffusers_keys]
|
||||||
|
# has unmatched keys
|
||||||
|
|
||||||
|
has_unmatched_keys = len(unmatched_ldm_keys) > 0 or len(unmatched_diffusers_keys) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def get_slices_from_string(s: str) -> tuple:
|
||||||
|
slice_strings = s.split(',')
|
||||||
|
slices = [eval(f"slice({component.strip()})") for component in slice_strings]
|
||||||
|
return tuple(slices)
|
||||||
|
|
||||||
|
|
||||||
|
if has_unmatched_keys:
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Found {len(unmatched_ldm_keys)} unmatched ldm keys and {len(unmatched_diffusers_keys)} unmatched diffusers keys")
|
||||||
|
|
||||||
|
unmatched_obj = OrderedDict()
|
||||||
|
unmatched_obj['ldm'] = OrderedDict()
|
||||||
|
unmatched_obj['diffusers'] = OrderedDict()
|
||||||
|
|
||||||
|
print(f"Gathering info on unmatched keys")
|
||||||
|
|
||||||
|
for key in tqdm(unmatched_ldm_keys, desc='Unmatched LDM keys'):
|
||||||
|
# get min, max, mean, std
|
||||||
|
weight = ldm_state_dict[key]
|
||||||
|
weight_min = weight.min().item()
|
||||||
|
weight_max = weight.max().item()
|
||||||
|
weight_mean = weight.mean().item()
|
||||||
|
weight_std = weight.std().item()
|
||||||
|
unmatched_obj['ldm'][key] = {
|
||||||
|
'shape': weight.shape,
|
||||||
|
"min": weight_min,
|
||||||
|
"max": weight_max,
|
||||||
|
"mean": weight_mean,
|
||||||
|
"std": weight_std,
|
||||||
|
}
|
||||||
|
del weight
|
||||||
|
flush()
|
||||||
|
|
||||||
|
for key in tqdm(unmatched_diffusers_keys, desc='Unmatched Diffusers keys'):
|
||||||
|
# get min, max, mean, std
|
||||||
|
weight = diffusers_state_dict[key]
|
||||||
|
weight_min = weight.min().item()
|
||||||
|
weight_max = weight.max().item()
|
||||||
|
weight_mean = weight.mean().item()
|
||||||
|
weight_std = weight.std().item()
|
||||||
|
unmatched_obj['diffusers'][key] = {
|
||||||
|
"shape": weight.shape,
|
||||||
|
"min": weight_min,
|
||||||
|
"max": weight_max,
|
||||||
|
"mean": weight_mean,
|
||||||
|
"std": weight_std,
|
||||||
|
}
|
||||||
|
del weight
|
||||||
|
flush()
|
||||||
|
|
||||||
|
unmatched_path = os.path.join(KEYMAPS_FOLDER, f'{name}_unmatched.json')
|
||||||
|
with open(unmatched_path, 'w') as f:
|
||||||
|
f.write(json.dumps(unmatched_obj, indent=4))
|
||||||
|
|
||||||
|
print(f'Saved unmatched keys to {unmatched_path}')
|
||||||
|
|
||||||
|
# save ldm remainders
|
||||||
|
remaining_ldm_values = OrderedDict()
|
||||||
|
for key in unmatched_ldm_keys:
|
||||||
|
remaining_ldm_values[key] = ldm_state_dict[key].detach().to('cpu', torch.float16)
|
||||||
|
|
||||||
|
save_file(remaining_ldm_values, os.path.join(KEYMAPS_FOLDER, f'{name}_ldm_base.safetensors'))
|
||||||
|
print(f'Saved remaining ldm values to {os.path.join(KEYMAPS_FOLDER, f"{name}_ldm_base.safetensors")}')
|
||||||
|
|
||||||
|
|
||||||
|
dest_path = os.path.join(KEYMAPS_FOLDER, f'{name}.json')
|
||||||
|
save_obj = OrderedDict()
|
||||||
|
save_obj["ldm_diffusers_keymap"] = ldm_diffusers_keymap
|
||||||
|
save_obj["ldm_diffusers_shape_map"] = ldm_diffusers_shape_map
|
||||||
|
save_obj["ldm_diffusers_operator_map"] = ldm_operator_map
|
||||||
|
save_obj["diffusers_ldm_operator_map"] = diffusers_operator_map
|
||||||
|
|
||||||
|
with open(dest_path, 'w') as f:
|
||||||
|
f.write(json.dumps(save_obj, indent=4))
|
||||||
|
|
||||||
|
print(f'Saved keymap to {dest_path}')
|
||||||
@@ -77,6 +77,7 @@ class ModelConfig:
|
|||||||
self.is_xl: bool = kwargs.get('is_xl', False)
|
self.is_xl: bool = kwargs.get('is_xl', False)
|
||||||
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
|
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
|
||||||
self.dtype: str = kwargs.get('dtype', 'float16')
|
self.dtype: str = kwargs.get('dtype', 'float16')
|
||||||
|
self.vae_path = kwargs.get('vae_path', None)
|
||||||
|
|
||||||
# only for SDXL models for now
|
# only for SDXL models for now
|
||||||
self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True)
|
self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True)
|
||||||
|
|||||||
3944
toolkit/keymaps/stable_diffusion_sdxl.json
Normal file
3944
toolkit/keymaps/stable_diffusion_sdxl.json
Normal file
File diff suppressed because it is too large
Load Diff
BIN
toolkit/keymaps/stable_diffusion_sdxl_ldm_base.safetensors
Normal file
BIN
toolkit/keymaps/stable_diffusion_sdxl_ldm_base.safetensors
Normal file
Binary file not shown.
43
toolkit/keymaps/stable_diffusion_sdxl_unmatched.json
Normal file
43
toolkit/keymaps/stable_diffusion_sdxl_unmatched.json
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
{
|
||||||
|
"ldm": {
|
||||||
|
"conditioner.embedders.0.transformer.text_model.embeddings.position_ids": {
|
||||||
|
"shape": [
|
||||||
|
1,
|
||||||
|
77
|
||||||
|
],
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 76.0,
|
||||||
|
"mean": 38.0,
|
||||||
|
"std": 22.375
|
||||||
|
},
|
||||||
|
"conditioner.embedders.1.model.logit_scale": {
|
||||||
|
"shape": [],
|
||||||
|
"min": 4.60546875,
|
||||||
|
"max": 4.60546875,
|
||||||
|
"mean": 4.60546875,
|
||||||
|
"std": NaN
|
||||||
|
},
|
||||||
|
"conditioner.embedders.1.model.text_projection": {
|
||||||
|
"shape": [
|
||||||
|
1280,
|
||||||
|
1280
|
||||||
|
],
|
||||||
|
"min": -0.15966796875,
|
||||||
|
"max": 0.230712890625,
|
||||||
|
"mean": 0.0,
|
||||||
|
"std": 0.0181732177734375
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"diffusers": {
|
||||||
|
"te1_text_projection.weight": {
|
||||||
|
"shape": [
|
||||||
|
1280,
|
||||||
|
1280
|
||||||
|
],
|
||||||
|
"min": -0.15966796875,
|
||||||
|
"max": 0.230712890625,
|
||||||
|
"mean": 2.128152846125886e-05,
|
||||||
|
"std": 0.018169498071074486
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|||||||
CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
|
CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
|
||||||
SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
|
SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
|
||||||
REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories")
|
REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories")
|
||||||
|
KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps")
|
||||||
|
|
||||||
# check if ENV variable is set
|
# check if ENV variable is set
|
||||||
if 'MODELS_PATH' in os.environ:
|
if 'MODELS_PATH' in os.environ:
|
||||||
|
|||||||
98
toolkit/saving.py
Normal file
98
toolkit/saving.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
|
|
||||||
|
from toolkit.train_tools import get_torch_dtype
|
||||||
|
from toolkit.paths import KEYMAPS_ROOT
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from toolkit.stable_diffusion_model import StableDiffusion
|
||||||
|
|
||||||
|
|
||||||
|
def get_slices_from_string(s: str) -> tuple:
|
||||||
|
slice_strings = s.split(',')
|
||||||
|
slices = [eval(f"slice({component.strip()})") for component in slice_strings]
|
||||||
|
return tuple(slices)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_state_dict_to_ldm_with_mapping(
|
||||||
|
diffusers_state_dict: 'OrderedDict',
|
||||||
|
mapping_path: str,
|
||||||
|
base_path: Union[str, None] = None,
|
||||||
|
device: str = 'cpu',
|
||||||
|
dtype: torch.dtype = torch.float32
|
||||||
|
) -> 'OrderedDict':
|
||||||
|
converted_state_dict = OrderedDict()
|
||||||
|
|
||||||
|
# load mapping
|
||||||
|
with open(mapping_path, 'r') as f:
|
||||||
|
mapping = json.load(f, object_pairs_hook=OrderedDict)
|
||||||
|
|
||||||
|
ldm_diffusers_keymap = mapping['ldm_diffusers_keymap']
|
||||||
|
ldm_diffusers_shape_map = mapping['ldm_diffusers_shape_map']
|
||||||
|
ldm_diffusers_operator_map = mapping['ldm_diffusers_operator_map']
|
||||||
|
|
||||||
|
# load base if it exists
|
||||||
|
# the base just has come keys like timing ids and stuff diffusers doesn't have or they don't match
|
||||||
|
if base_path is not None:
|
||||||
|
converted_state_dict = load_file(base_path, device)
|
||||||
|
# convert to the right dtype
|
||||||
|
for key in converted_state_dict:
|
||||||
|
converted_state_dict[key] = converted_state_dict[key].to(device, dtype=dtype)
|
||||||
|
|
||||||
|
# process operators first
|
||||||
|
for ldm_key in ldm_diffusers_operator_map:
|
||||||
|
# if the key cat is in the ldm key, we need to process it
|
||||||
|
if 'cat' in ldm_key:
|
||||||
|
cat_list = []
|
||||||
|
for diffusers_key in ldm_diffusers_operator_map[ldm_key]['cat']:
|
||||||
|
cat_list.append(diffusers_state_dict[diffusers_key].detatch())
|
||||||
|
converted_state_dict[ldm_key] = torch.cat(cat_list, dim=0).to(device, dtype=dtype)
|
||||||
|
if 'slice' in ldm_key:
|
||||||
|
tensor_to_slice = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][0]]
|
||||||
|
slice_text = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][1]]
|
||||||
|
converted_state_dict[ldm_key] = tensor_to_slice[get_slices_from_string(slice_text)].detatch().to(device,
|
||||||
|
dtype=dtype)
|
||||||
|
|
||||||
|
# process the rest of the keys
|
||||||
|
for ldm_key in ldm_diffusers_keymap:
|
||||||
|
# if the key is in the ldm key, we need to process it
|
||||||
|
if ldm_diffusers_keymap[ldm_key] in diffusers_state_dict:
|
||||||
|
tensor = diffusers_state_dict[ldm_diffusers_keymap[ldm_key]].detach().to(device, dtype=dtype)
|
||||||
|
# see if we need to reshape
|
||||||
|
if ldm_key in ldm_diffusers_shape_map:
|
||||||
|
tensor = tensor.view(ldm_diffusers_shape_map[ldm_key][0])
|
||||||
|
converted_state_dict[ldm_key] = tensor
|
||||||
|
|
||||||
|
return converted_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def save_ldm_model_from_diffusers(
|
||||||
|
sd: 'StableDiffusion',
|
||||||
|
output_file: str,
|
||||||
|
meta: 'OrderedDict',
|
||||||
|
save_dtype=get_torch_dtype('fp16'),
|
||||||
|
sd_version: Literal['1', '2', 'sdxl'] = '2'
|
||||||
|
):
|
||||||
|
if sd_version != 'sdxl':
|
||||||
|
# not supported yet
|
||||||
|
raise NotImplementedError("Only SDXL is supported at this time with this method")
|
||||||
|
# load our base
|
||||||
|
base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
|
||||||
|
mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')
|
||||||
|
|
||||||
|
# convert the state dict
|
||||||
|
converted_state_dict = convert_state_dict_to_ldm_with_mapping(
|
||||||
|
sd.state_dict(),
|
||||||
|
mapping_path,
|
||||||
|
base_path,
|
||||||
|
device='cpu',
|
||||||
|
dtype=save_dtype
|
||||||
|
)
|
||||||
|
# make sure parent folder exists
|
||||||
|
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||||
|
save_file(converted_state_dict, output_file, metadata=meta)
|
||||||
@@ -1,8 +1,9 @@
|
|||||||
import gc
|
import gc
|
||||||
import typing
|
import typing
|
||||||
from typing import Union, OrderedDict, List, Tuple
|
from typing import Union, List, Tuple
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
@@ -10,11 +11,12 @@ from tqdm import tqdm
|
|||||||
from torchvision.transforms import Resize
|
from torchvision.transforms import Resize
|
||||||
|
|
||||||
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
|
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
|
||||||
convert_vae_state_dict
|
convert_vae_state_dict, load_vae
|
||||||
from toolkit import train_tools
|
from toolkit import train_tools
|
||||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
||||||
from toolkit.metadata import get_meta_for_safetensors
|
from toolkit.metadata import get_meta_for_safetensors
|
||||||
from toolkit.paths import REPOS_ROOT
|
from toolkit.paths import REPOS_ROOT
|
||||||
|
from toolkit.saving import save_ldm_model_from_diffusers
|
||||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||||
import torch
|
import torch
|
||||||
from library import model_util
|
from library import model_util
|
||||||
@@ -27,6 +29,13 @@ import diffusers
|
|||||||
# tell it to shut up
|
# tell it to shut up
|
||||||
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
|
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
|
||||||
|
|
||||||
|
VAE_PREFIX_UNET = "vae"
|
||||||
|
SD_PREFIX_UNET = "unet"
|
||||||
|
SD_PREFIX_TEXT_ENCODER = "te"
|
||||||
|
|
||||||
|
SD_PREFIX_TEXT_ENCODER1 = "te1"
|
||||||
|
SD_PREFIX_TEXT_ENCODER2 = "te2"
|
||||||
|
|
||||||
|
|
||||||
class BlankNetwork:
|
class BlankNetwork:
|
||||||
multiplier = 1.0
|
multiplier = 1.0
|
||||||
@@ -218,6 +227,10 @@ class StableDiffusion:
|
|||||||
# scheduler doesn't get set sometimes, so we set it here
|
# scheduler doesn't get set sometimes, so we set it here
|
||||||
pipe.scheduler = scheduler
|
pipe.scheduler = scheduler
|
||||||
|
|
||||||
|
if self.model_config.vae_path is not None:
|
||||||
|
external_vae = load_vae(self.model_config.vae_path, dtype)
|
||||||
|
pipe.vae = external_vae
|
||||||
|
|
||||||
self.unet = pipe.unet
|
self.unet = pipe.unet
|
||||||
self.noise_scheduler = pipe.scheduler
|
self.noise_scheduler = pipe.scheduler
|
||||||
self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
|
self.vae = pipe.vae.to(self.device_torch, dtype=dtype)
|
||||||
@@ -630,8 +643,33 @@ class StableDiffusion:
|
|||||||
|
|
||||||
raise ValueError(f"Unknown weight name: {name}")
|
raise ValueError(f"Unknown weight name: {name}")
|
||||||
|
|
||||||
|
def state_dict(self, vae=True, text_encoder=True, unet=True):
|
||||||
|
state_dict = OrderedDict()
|
||||||
|
if vae:
|
||||||
|
for k, v in self.vae.state_dict().items():
|
||||||
|
new_key = k if k.startswith(f"{VAE_PREFIX_UNET}") else f"{VAE_PREFIX_UNET}_{k}"
|
||||||
|
state_dict[new_key] = v
|
||||||
|
if text_encoder:
|
||||||
|
if isinstance(self.text_encoder, list):
|
||||||
|
for i, encoder in enumerate(self.text_encoder):
|
||||||
|
for k, v in encoder.state_dict().items():
|
||||||
|
new_key = k if k.startswith(
|
||||||
|
f"{SD_PREFIX_TEXT_ENCODER}{i}") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}"
|
||||||
|
state_dict[new_key] = v
|
||||||
|
else:
|
||||||
|
for k, v in self.text_encoder.state_dict().items():
|
||||||
|
new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER}") else f"{SD_PREFIX_TEXT_ENCODER}_{k}"
|
||||||
|
state_dict[new_key] = v
|
||||||
|
if unet:
|
||||||
|
for k, v in self.unet.state_dict().items():
|
||||||
|
new_key = k if k.startswith(f"{SD_PREFIX_UNET}") else f"{SD_PREFIX_UNET}_{k}"
|
||||||
|
state_dict[new_key] = v
|
||||||
|
return state_dict
|
||||||
|
|
||||||
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
|
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
|
# prepare metadata
|
||||||
|
meta = get_meta_for_safetensors(meta)
|
||||||
|
|
||||||
def update_sd(prefix, sd):
|
def update_sd(prefix, sd):
|
||||||
for k, v in sd.items():
|
for k, v in sd.items():
|
||||||
@@ -644,14 +682,13 @@ class StableDiffusion:
|
|||||||
|
|
||||||
# todo see what logit scale is
|
# todo see what logit scale is
|
||||||
if self.is_xl:
|
if self.is_xl:
|
||||||
# Convert the UNet model
|
save_ldm_model_from_diffusers(
|
||||||
update_sd("model.diffusion_model.", self.unet.state_dict())
|
sd=self,
|
||||||
|
output_file=output_file,
|
||||||
# Convert the text encoders
|
meta=meta,
|
||||||
update_sd("conditioner.embedders.0.transformer.", self.text_encoder[0].state_dict())
|
save_dtype=save_dtype,
|
||||||
|
sd_version='sdxl',
|
||||||
text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(self.text_encoder[1].state_dict(), logit_scale)
|
)
|
||||||
update_sd("conditioner.embedders.1.model.", text_enc2_dict)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Convert the UNet model
|
# Convert the UNet model
|
||||||
@@ -667,13 +704,11 @@ class StableDiffusion:
|
|||||||
text_enc_dict = self.text_encoder.state_dict()
|
text_enc_dict = self.text_encoder.state_dict()
|
||||||
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
||||||
|
|
||||||
# Convert the VAE
|
# Convert the VAE
|
||||||
if self.vae is not None:
|
if self.vae is not None:
|
||||||
vae_dict = model_util.convert_vae_state_dict(self.vae.state_dict())
|
vae_dict = model_util.convert_vae_state_dict(self.vae.state_dict())
|
||||||
update_sd("first_stage_model.", vae_dict)
|
update_sd("first_stage_model.", vae_dict)
|
||||||
|
|
||||||
# prepare metadata
|
# make sure parent folder exists
|
||||||
meta = get_meta_for_safetensors(meta)
|
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||||
# make sure parent folder exists
|
save_file(state_dict, output_file, metadata=meta)
|
||||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
|
||||||
save_file(state_dict, output_file, metadata=meta)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user