Handle conversions back to ldm for saving

Jaret Burkett
2023-07-18 19:34:35 -06:00
parent 17c13eef88
commit 8d8229dfc0
5 changed files with 586 additions and 61 deletions


@@ -1,4 +1,5 @@
import copy
import glob
import os
import time
from collections import OrderedDict
@@ -12,8 +13,9 @@ from torch import nn
from torchvision.transforms import transforms
from jobs.process import BaseTrainProcess
from toolkit.kohya_model_util import load_vae
from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
from toolkit.data_loader import ImageDataset
from toolkit.losses import ComparativeTotalVariation
from toolkit.metadata import get_meta_for_safetensors
from toolkit.style import get_style_model_and_losses
from toolkit.train_tools import get_torch_dtype
@@ -57,7 +59,7 @@ class TrainVAEProcess(BaseTrainProcess):
self.content_weight = self.get_conf('content_weight', 0)
self.kld_weight = self.get_conf('kld_weight', 0)
self.mse_weight = self.get_conf('mse_weight', 1e0)
self.tv_weight = self.get_conf('tv_weight', 1e0)
self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
self.writer = self.job.writer
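For reference, a hypothetical config fragment exercising the options read above; only the key names are taken from the get_conf calls in this hunk, while the surrounding structure and values are illustrative, not part of the commit:

# hypothetical values; only the keys come from the get_conf calls above
process_config = {
    'style_weight': 0,
    'content_weight': 0,
    'kld_weight': 0,
    'mse_weight': 1.0,
    'tv_weight': 1.0,  # weight for the new comparative total-variation term
    'blocks_to_train': ['up_blocks', 'conv_out'],  # or ['all'] to train the whole decoder
}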
@@ -114,7 +116,7 @@ class TrainVAEProcess(BaseTrainProcess):
def setup_vgg19(self):
if self.vgg_19 is None:
self.vgg_19, self.style_losses, self.content_losses = get_style_model_and_losses(
self.vgg_19, self.style_losses, self.content_losses, output = get_style_model_and_losses(
single_target=True, device=self.device)
self.vgg_19.requires_grad_(False)
@@ -149,6 +151,15 @@ class TrainVAEProcess(BaseTrainProcess):
else:
return torch.tensor(0.0, device=self.device)
def get_tv_loss(self, pred, target):
if self.tv_weight > 0:
tv_loss_fn = ComparativeTotalVariation()
loss = tv_loss_fn(pred, target)
return loss
else:
return torch.tensor(0.0, device=self.device)
def save(self, step=None):
if not os.path.exists(self.save_root):
os.makedirs(self.save_root, exist_ok=True)
@@ -162,7 +173,7 @@ class TrainVAEProcess(BaseTrainProcess):
# prepare meta
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
state_dict = self.vae.state_dict()
state_dict = convert_diffusers_back_to_ldm(self.vae)
for key in list(state_dict.keys()):
v = state_dict[key]
@@ -219,6 +230,30 @@ class TrainVAEProcess(BaseTrainProcess):
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
output_img.save(os.path.join(sample_folder, filename))
def load_vae(self):
path_to_load = self.vae_path
# see if we have a checkpoint in our output to resume from
self.print(f"Looking for latest checkpoint in {self.save_root}")
files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*.safetensors"))
if files and len(files) > 0:
latest_file = max(files, key=os.path.getmtime)
print(f" - Latest checkpoint is: {latest_file}")
path_to_load = latest_file
# todo update step and epoch count
else:
self.print(f" - No checkpoint found, starting from scratch")
# load vae
self.print(f"Loading VAE")
self.print(f" - Loading VAE: {path_to_load}")
if self.vae is None:
self.vae = load_vae(path_to_load, dtype=self.torch_dtype)
# set decoder to train
self.vae.to(self.device, dtype=self.torch_dtype)
self.vae.requires_grad_(False)
self.vae.eval()
self.vae.decoder.train()
def run(self):
super().run()
self.load_datasets()
@@ -241,16 +276,7 @@ class TrainVAEProcess(BaseTrainProcess):
self.print(f" - Max steps: {self.max_steps}")
# load vae
self.print(f"Loading VAE")
self.print(f" - Loading VAE: {self.vae_path}")
if self.vae is None:
self.vae = load_vae(self.vae_path, dtype=self.torch_dtype)
# set decoder to train
self.vae.to(self.device, dtype=self.torch_dtype)
self.vae.requires_grad_(False)
self.vae.eval()
self.vae.decoder.train()
self.load_vae()
params = []
@@ -260,18 +286,22 @@ class TrainVAEProcess(BaseTrainProcess):
train_all = 'all' in self.blocks_to_train
# mid_block
if train_all or 'mid_block' in self.blocks_to_train:
params += list(self.vae.decoder.mid_block.parameters())
self.vae.decoder.mid_block.requires_grad_(True)
# up_blocks
if train_all or 'up_blocks' in self.blocks_to_train:
params += list(self.vae.decoder.up_blocks.parameters())
self.vae.decoder.up_blocks.requires_grad_(True)
# conv_out (single conv layer output)
if train_all or 'conv_out' in self.blocks_to_train:
params += list(self.vae.decoder.conv_out.parameters())
self.vae.decoder.conv_out.requires_grad_(True)
if train_all:
params = list(self.vae.decoder.parameters())
self.vae.decoder.requires_grad_(True)
else:
# mid_block
if train_all or 'mid_block' in self.blocks_to_train:
params += list(self.vae.decoder.mid_block.parameters())
self.vae.decoder.mid_block.requires_grad_(True)
# up_blocks
if train_all or 'up_blocks' in self.blocks_to_train:
params += list(self.vae.decoder.up_blocks.parameters())
self.vae.decoder.up_blocks.requires_grad_(True)
# conv_out (single conv layer output)
if train_all or 'conv_out' in self.blocks_to_train:
params += list(self.vae.decoder.conv_out.parameters())
self.vae.decoder.conv_out.requires_grad_(True)
if self.style_weight > 0 or self.content_weight > 0:
self.setup_vgg19()
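A small sanity-check sketch (hypothetical, not in the commit): after the selection above, only the chosen decoder sub-modules should require gradients, since load_vae() froze the whole model first.

# e.g. with blocks_to_train = ['up_blocks', 'conv_out'] and train_all False
trainable = [n for n, p in self.vae.named_parameters() if p.requires_grad]
assert all(n.startswith(('decoder.up_blocks.', 'decoder.conv_out.')) for n in trainable)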
@@ -305,7 +335,8 @@ class TrainVAEProcess(BaseTrainProcess):
"style": [],
"content": [],
"mse": [],
"kl": []
"kl": [],
"tv": [],
})
epoch_losses = copy.deepcopy(blank_losses)
log_losses = copy.deepcopy(blank_losses)
@@ -337,8 +368,9 @@ class TrainVAEProcess(BaseTrainProcess):
content_loss = self.get_content_loss() * self.content_weight
kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
loss = style_loss + content_loss + kld_loss + mse_loss
loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss
# Backward pass and optimization
optimizer.zero_grad()
@@ -358,6 +390,8 @@ class TrainVAEProcess(BaseTrainProcess):
loss_string += f" kld: {kld_loss.item():.2e}"
if self.mse_weight > 0:
loss_string += f" mse: {mse_loss.item():.2e}"
if self.tv_weight > 0:
loss_string += f" tv: {tv_loss.item():.2e}"
learning_rate = optimizer.param_groups[0]['lr']
self.progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}")
@@ -369,12 +403,14 @@ class TrainVAEProcess(BaseTrainProcess):
epoch_losses["content"].append(content_loss.item())
epoch_losses["mse"].append(mse_loss.item())
epoch_losses["kl"].append(kld_loss.item())
epoch_losses["tv"].append(tv_loss.item())
log_losses["total"].append(loss_value)
log_losses["style"].append(style_loss.item())
log_losses["content"].append(content_loss.item())
log_losses["mse"].append(mse_loss.item())
log_losses["kl"].append(kld_loss.item())
log_losses["tv"].append(tv_loss.item())
if step != 0:
if self.sample_every and step % self.sample_every == 0:

testing/test_vae_cycle.py (new file, 112 lines)

@@ -0,0 +1,112 @@
import os
import torch
from safetensors.torch import load_file
from collections import OrderedDict
from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm, vae_keys_squished_on_diffusers
import json
# this was just used to match the vae keys to the diffusers keys
# you probably won't need this, unless they change them again... and again
# on second thought, you probably will
device = torch.device('cpu')
dtype = torch.float32
vae_path = '/mnt/Models/stable-diffusion/models/VAE/vae-ft-mse-840000-ema-pruned/vae-ft-mse-840000-ema-pruned.safetensors'
find_matches = False
state_dict_ldm = load_file(vae_path)
diffusers_vae = load_vae(vae_path, dtype=torch.float32).to(device)
ldm_keys = state_dict_ldm.keys()
matched_keys = {}
duplicated_keys = {}
if find_matches:
# find values that match with a very low mse
for ldm_key in ldm_keys:
ldm_value = state_dict_ldm[ldm_key]
for diffusers_key in list(diffusers_vae.state_dict().keys()):
diffusers_value = diffusers_vae.state_dict()[diffusers_key]
if diffusers_key in vae_keys_squished_on_diffusers:
diffusers_value = diffusers_value.clone().unsqueeze(-1).unsqueeze(-1)
# if they are not same shape, skip
if ldm_value.shape != diffusers_value.shape:
continue
mse = torch.nn.functional.mse_loss(ldm_value, diffusers_value)
if mse < 1e-6:
if ldm_key in list(matched_keys.keys()):
print(f'{ldm_key} already matched to {matched_keys[ldm_key]}')
if ldm_key in duplicated_keys:
duplicated_keys[ldm_key].append(diffusers_key)
else:
duplicated_keys[ldm_key] = [diffusers_key]
continue
matched_keys[ldm_key] = diffusers_key
is_matched = True
break
print(f'Found {len(matched_keys)} matches')
dif_to_ldm_state_dict = convert_diffusers_back_to_ldm(diffusers_vae)
dif_to_ldm_state_dict_keys = list(dif_to_ldm_state_dict.keys())
keys_in_both = []
keys_not_in_diffusers = []
for key in ldm_keys:
if key not in dif_to_ldm_state_dict_keys:
keys_not_in_diffusers.append(key)
keys_not_in_ldm = []
for key in dif_to_ldm_state_dict_keys:
if key not in ldm_keys:
keys_not_in_ldm.append(key)
keys_in_both = []
for key in ldm_keys:
if key in dif_to_ldm_state_dict_keys:
keys_in_both.append(key)
# sort them
keys_not_in_diffusers.sort()
keys_not_in_ldm.sort()
keys_in_both.sort()
# print(f'Keys in LDM but not in Diffusers: {len(keys_not_in_diffusers)}{keys_not_in_diffusers}')
# print(f'Keys in Diffusers but not in LDM: {len(keys_not_in_ldm)}{keys_not_in_ldm}')
# print(f'Keys in both: {len(keys_in_both)}{keys_in_both}')
json_data = {
"both": keys_in_both,
"ldm": keys_not_in_diffusers,
"diffusers": keys_not_in_ldm
}
json_data = json.dumps(json_data, indent=4)
remaining_diffusers_values = OrderedDict()
for key in keys_not_in_ldm:
remaining_diffusers_values[key] = dif_to_ldm_state_dict[key]
# print(remaining_diffusers_values.keys())
remaining_ldm_values = OrderedDict()
for key in keys_not_in_diffusers:
remaining_ldm_values[key] = state_dict_ldm[key]
# print(json_data)
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
json_save_path = os.path.join(project_root, 'config', 'keys.json')
json_matched_save_path = os.path.join(project_root, 'config', 'matched.json')
json_duped_save_path = os.path.join(project_root, 'config', 'duped.json')
with open(json_save_path, 'w') as f:
f.write(json_data)
if find_matches:
with open(json_matched_save_path, 'w') as f:
f.write(json.dumps(matched_keys, indent=4))
with open(json_duped_save_path, 'w') as f:
f.write(json.dumps(duplicated_keys, indent=4))
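A complementary round-trip check (not part of the commit) that compares tensor values rather than just key sets, using the helpers this commit adds; the checkpoint path is a placeholder and a standard SD 1.x VAE in safetensors format is assumed:

import torch
from safetensors.torch import load_file
from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm

ldm_sd = load_file('/path/to/vae.safetensors')  # placeholder path
vae = load_vae('/path/to/vae.safetensors', dtype=torch.float32)
roundtrip_sd = convert_diffusers_back_to_ldm(vae)

# keys that survive the ldm -> diffusers -> ldm cycle should carry identical values
for key in ldm_sd:
    if key in roundtrip_sd and ldm_sd[key].shape == roundtrip_sd[key].shape:
        assert torch.allclose(ldm_sd[key], roundtrip_sd[key], atol=1e-6), key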

toolkit/kohya_model_util.py

@@ -1,16 +1,19 @@
# mostly from https://github.com/kohya-ss/sd-scripts/blob/main/library/model_util.py
# I am infinitely grateful to @kohya-ss for their amazing work in this field.
# This version is updated to handle the latest version of the diffusers library.
import json
# v1: split from train_db_fixed.py.
# v2: support safetensors
import math
import os
import re
import torch
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from safetensors.torch import load_file, save_file
from collections import OrderedDict
# model parameters for the Diffusers version of Stable Diffusion
NUM_TRAIN_TIMESTEPS = 1000
@@ -151,7 +154,7 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
def assign_to_checkpoint(
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming
@@ -228,6 +231,7 @@ def linear_transformer_to_conv(checkpoint):
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
mapping = {}
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
@@ -258,33 +262,40 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
# Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = {
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in
range(num_input_blocks)
}
# Retrieves the keys for the middle blocks only
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
middle_blocks = {
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in
range(num_middle_blocks)
}
# Retrieves the keys for the output blocks only
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
output_blocks = {
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in
range(num_output_blocks)
}
for i in range(1, num_input_blocks):
block_id = (i - 1) // (config["layers_per_block"] + 1)
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
resnets = [key for key in input_blocks[i] if
f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
mapping[f'input_blocks.{i}.0.op.weight'] = f"down_blocks.{block_id}.downsamplers.0.conv.weight"
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.bias")
mapping[f'input_blocks.{i}.0.op.bias'] = f"down_blocks.{block_id}.downsamplers.0.conv.bias"
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
@@ -293,7 +304,8 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path],
config=config)
resnet_0 = middle_blocks[0]
attentions = middle_blocks[1]
@@ -307,7 +319,8 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
attentions_paths = renew_attention_paths(attentions)
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path],
config=config)
for i in range(num_output_blocks):
block_id = i // (config["layers_per_block"] + 1)
@@ -330,7 +343,8 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path],
config=config)
# original:
# if ["conv.weight", "conv.bias"] in output_block_list.values():
@@ -359,7 +373,8 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
"old": f"output_blocks.{i}.1",
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
}
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path],
config=config)
else:
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
for path in resnet_0_paths:
@@ -373,10 +388,326 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
if v2 and not config.get('use_linear_projection', False):
linear_transformer_to_conv(new_checkpoint)
# print("mapping: ", json.dumps(mapping, indent=4))
return new_checkpoint
# ldm key: diffusers key
vae_ldm_to_diffusers_dict = {
"decoder.conv_in.bias": "decoder.conv_in.bias",
"decoder.conv_in.weight": "decoder.conv_in.weight",
"decoder.conv_out.bias": "decoder.conv_out.bias",
"decoder.conv_out.weight": "decoder.conv_out.weight",
"decoder.mid.attn_1.k.bias": "decoder.mid_block.attentions.0.to_k.bias",
"decoder.mid.attn_1.k.weight": "decoder.mid_block.attentions.0.to_k.weight",
"decoder.mid.attn_1.norm.bias": "decoder.mid_block.attentions.0.group_norm.bias",
"decoder.mid.attn_1.norm.weight": "decoder.mid_block.attentions.0.group_norm.weight",
"decoder.mid.attn_1.proj_out.bias": "decoder.mid_block.attentions.0.to_out.0.bias",
"decoder.mid.attn_1.proj_out.weight": "decoder.mid_block.attentions.0.to_out.0.weight",
"decoder.mid.attn_1.q.bias": "decoder.mid_block.attentions.0.to_q.bias",
"decoder.mid.attn_1.q.weight": "decoder.mid_block.attentions.0.to_q.weight",
"decoder.mid.attn_1.v.bias": "decoder.mid_block.attentions.0.to_v.bias",
"decoder.mid.attn_1.v.weight": "decoder.mid_block.attentions.0.to_v.weight",
"decoder.mid.block_1.conv1.bias": "decoder.mid_block.resnets.0.conv1.bias",
"decoder.mid.block_1.conv1.weight": "decoder.mid_block.resnets.0.conv1.weight",
"decoder.mid.block_1.conv2.bias": "decoder.mid_block.resnets.0.conv2.bias",
"decoder.mid.block_1.conv2.weight": "decoder.mid_block.resnets.0.conv2.weight",
"decoder.mid.block_1.norm1.bias": "decoder.mid_block.resnets.0.norm1.bias",
"decoder.mid.block_1.norm1.weight": "decoder.mid_block.resnets.0.norm1.weight",
"decoder.mid.block_1.norm2.bias": "decoder.mid_block.resnets.0.norm2.bias",
"decoder.mid.block_1.norm2.weight": "decoder.mid_block.resnets.0.norm2.weight",
"decoder.mid.block_2.conv1.bias": "decoder.mid_block.resnets.1.conv1.bias",
"decoder.mid.block_2.conv1.weight": "decoder.mid_block.resnets.1.conv1.weight",
"decoder.mid.block_2.conv2.bias": "decoder.mid_block.resnets.1.conv2.bias",
"decoder.mid.block_2.conv2.weight": "decoder.mid_block.resnets.1.conv2.weight",
"decoder.mid.block_2.norm1.bias": "decoder.mid_block.resnets.1.norm1.bias",
"decoder.mid.block_2.norm1.weight": "decoder.mid_block.resnets.1.norm1.weight",
"decoder.mid.block_2.norm2.bias": "decoder.mid_block.resnets.1.norm2.bias",
"decoder.mid.block_2.norm2.weight": "decoder.mid_block.resnets.1.norm2.weight",
"decoder.norm_out.bias": "decoder.conv_norm_out.bias",
"decoder.norm_out.weight": "decoder.conv_norm_out.weight",
"decoder.up.0.block.0.conv1.bias": "decoder.up_blocks.3.resnets.0.conv1.bias",
"decoder.up.0.block.0.conv1.weight": "decoder.up_blocks.3.resnets.0.conv1.weight",
"decoder.up.0.block.0.conv2.bias": "decoder.up_blocks.3.resnets.0.conv2.bias",
"decoder.up.0.block.0.conv2.weight": "decoder.up_blocks.3.resnets.0.conv2.weight",
"decoder.up.0.block.0.nin_shortcut.bias": "decoder.up_blocks.3.resnets.0.conv_shortcut.bias",
"decoder.up.0.block.0.nin_shortcut.weight": "decoder.up_blocks.3.resnets.0.conv_shortcut.weight",
"decoder.up.0.block.0.norm1.bias": "decoder.up_blocks.3.resnets.0.norm1.bias",
"decoder.up.0.block.0.norm1.weight": "decoder.up_blocks.3.resnets.0.norm1.weight",
"decoder.up.0.block.0.norm2.bias": "decoder.up_blocks.3.resnets.0.norm2.bias",
"decoder.up.0.block.0.norm2.weight": "decoder.up_blocks.3.resnets.0.norm2.weight",
"decoder.up.0.block.1.conv1.bias": "decoder.up_blocks.3.resnets.1.conv1.bias",
"decoder.up.0.block.1.conv1.weight": "decoder.up_blocks.3.resnets.1.conv1.weight",
"decoder.up.0.block.1.conv2.bias": "decoder.up_blocks.3.resnets.1.conv2.bias",
"decoder.up.0.block.1.conv2.weight": "decoder.up_blocks.3.resnets.1.conv2.weight",
"decoder.up.0.block.1.norm1.bias": "decoder.up_blocks.3.resnets.1.norm1.bias",
"decoder.up.0.block.1.norm1.weight": "decoder.up_blocks.3.resnets.1.norm1.weight",
"decoder.up.0.block.1.norm2.bias": "decoder.up_blocks.3.resnets.1.norm2.bias",
"decoder.up.0.block.1.norm2.weight": "decoder.up_blocks.3.resnets.1.norm2.weight",
"decoder.up.0.block.2.conv1.bias": "decoder.up_blocks.3.resnets.2.conv1.bias",
"decoder.up.0.block.2.conv1.weight": "decoder.up_blocks.3.resnets.2.conv1.weight",
"decoder.up.0.block.2.conv2.bias": "decoder.up_blocks.3.resnets.2.conv2.bias",
"decoder.up.0.block.2.conv2.weight": "decoder.up_blocks.3.resnets.2.conv2.weight",
"decoder.up.0.block.2.norm1.bias": "decoder.up_blocks.3.resnets.2.norm1.bias",
"decoder.up.0.block.2.norm1.weight": "decoder.up_blocks.3.resnets.2.norm1.weight",
"decoder.up.0.block.2.norm2.bias": "decoder.up_blocks.3.resnets.2.norm2.bias",
"decoder.up.0.block.2.norm2.weight": "decoder.up_blocks.3.resnets.2.norm2.weight",
"decoder.up.1.block.0.conv1.bias": "decoder.up_blocks.2.resnets.0.conv1.bias",
"decoder.up.1.block.0.conv1.weight": "decoder.up_blocks.2.resnets.0.conv1.weight",
"decoder.up.1.block.0.conv2.bias": "decoder.up_blocks.2.resnets.0.conv2.bias",
"decoder.up.1.block.0.conv2.weight": "decoder.up_blocks.2.resnets.0.conv2.weight",
"decoder.up.1.block.0.nin_shortcut.bias": "decoder.up_blocks.2.resnets.0.conv_shortcut.bias",
"decoder.up.1.block.0.nin_shortcut.weight": "decoder.up_blocks.2.resnets.0.conv_shortcut.weight",
"decoder.up.1.block.0.norm1.bias": "decoder.up_blocks.2.resnets.0.norm1.bias",
"decoder.up.1.block.0.norm1.weight": "decoder.up_blocks.2.resnets.0.norm1.weight",
"decoder.up.1.block.0.norm2.bias": "decoder.up_blocks.2.resnets.0.norm2.bias",
"decoder.up.1.block.0.norm2.weight": "decoder.up_blocks.2.resnets.0.norm2.weight",
"decoder.up.1.block.1.conv1.bias": "decoder.up_blocks.2.resnets.1.conv1.bias",
"decoder.up.1.block.1.conv1.weight": "decoder.up_blocks.2.resnets.1.conv1.weight",
"decoder.up.1.block.1.conv2.bias": "decoder.up_blocks.2.resnets.1.conv2.bias",
"decoder.up.1.block.1.conv2.weight": "decoder.up_blocks.2.resnets.1.conv2.weight",
"decoder.up.1.block.1.norm1.bias": "decoder.up_blocks.2.resnets.1.norm1.bias",
"decoder.up.1.block.1.norm1.weight": "decoder.up_blocks.2.resnets.1.norm1.weight",
"decoder.up.1.block.1.norm2.bias": "decoder.up_blocks.2.resnets.1.norm2.bias",
"decoder.up.1.block.1.norm2.weight": "decoder.up_blocks.2.resnets.1.norm2.weight",
"decoder.up.1.block.2.conv1.bias": "decoder.up_blocks.2.resnets.2.conv1.bias",
"decoder.up.1.block.2.conv1.weight": "decoder.up_blocks.2.resnets.2.conv1.weight",
"decoder.up.1.block.2.conv2.bias": "decoder.up_blocks.2.resnets.2.conv2.bias",
"decoder.up.1.block.2.conv2.weight": "decoder.up_blocks.2.resnets.2.conv2.weight",
"decoder.up.1.block.2.norm1.bias": "decoder.up_blocks.2.resnets.2.norm1.bias",
"decoder.up.1.block.2.norm1.weight": "decoder.up_blocks.2.resnets.2.norm1.weight",
"decoder.up.1.block.2.norm2.bias": "decoder.up_blocks.2.resnets.2.norm2.bias",
"decoder.up.1.block.2.norm2.weight": "decoder.up_blocks.2.resnets.2.norm2.weight",
"decoder.up.1.upsample.conv.bias": "decoder.up_blocks.2.upsamplers.0.conv.bias",
"decoder.up.1.upsample.conv.weight": "decoder.up_blocks.2.upsamplers.0.conv.weight",
"decoder.up.2.block.0.conv1.bias": "decoder.up_blocks.1.resnets.0.conv1.bias",
"decoder.up.2.block.0.conv1.weight": "decoder.up_blocks.1.resnets.0.conv1.weight",
"decoder.up.2.block.0.conv2.bias": "decoder.up_blocks.1.resnets.0.conv2.bias",
"decoder.up.2.block.0.conv2.weight": "decoder.up_blocks.1.resnets.0.conv2.weight",
"decoder.up.2.block.0.norm1.bias": "decoder.up_blocks.1.resnets.0.norm1.bias",
"decoder.up.2.block.0.norm1.weight": "decoder.up_blocks.1.resnets.0.norm1.weight",
"decoder.up.2.block.0.norm2.bias": "decoder.up_blocks.1.resnets.0.norm2.bias",
"decoder.up.2.block.0.norm2.weight": "decoder.up_blocks.1.resnets.0.norm2.weight",
"decoder.up.2.block.1.conv1.bias": "decoder.up_blocks.1.resnets.1.conv1.bias",
"decoder.up.2.block.1.conv1.weight": "decoder.up_blocks.1.resnets.1.conv1.weight",
"decoder.up.2.block.1.conv2.bias": "decoder.up_blocks.1.resnets.1.conv2.bias",
"decoder.up.2.block.1.conv2.weight": "decoder.up_blocks.1.resnets.1.conv2.weight",
"decoder.up.2.block.1.norm1.bias": "decoder.up_blocks.1.resnets.1.norm1.bias",
"decoder.up.2.block.1.norm1.weight": "decoder.up_blocks.1.resnets.1.norm1.weight",
"decoder.up.2.block.1.norm2.bias": "decoder.up_blocks.1.resnets.1.norm2.bias",
"decoder.up.2.block.1.norm2.weight": "decoder.up_blocks.1.resnets.1.norm2.weight",
"decoder.up.2.block.2.conv1.bias": "decoder.up_blocks.1.resnets.2.conv1.bias",
"decoder.up.2.block.2.conv1.weight": "decoder.up_blocks.1.resnets.2.conv1.weight",
"decoder.up.2.block.2.conv2.bias": "decoder.up_blocks.1.resnets.2.conv2.bias",
"decoder.up.2.block.2.conv2.weight": "decoder.up_blocks.1.resnets.2.conv2.weight",
"decoder.up.2.block.2.norm1.bias": "decoder.up_blocks.1.resnets.2.norm1.bias",
"decoder.up.2.block.2.norm1.weight": "decoder.up_blocks.1.resnets.2.norm1.weight",
"decoder.up.2.block.2.norm2.bias": "decoder.up_blocks.1.resnets.2.norm2.bias",
"decoder.up.2.block.2.norm2.weight": "decoder.up_blocks.1.resnets.2.norm2.weight",
"decoder.up.2.upsample.conv.bias": "decoder.up_blocks.1.upsamplers.0.conv.bias",
"decoder.up.2.upsample.conv.weight": "decoder.up_blocks.1.upsamplers.0.conv.weight",
"decoder.up.3.block.0.conv1.bias": "decoder.up_blocks.0.resnets.0.conv1.bias",
"decoder.up.3.block.0.conv1.weight": "decoder.up_blocks.0.resnets.0.conv1.weight",
"decoder.up.3.block.0.conv2.bias": "decoder.up_blocks.0.resnets.0.conv2.bias",
"decoder.up.3.block.0.conv2.weight": "decoder.up_blocks.0.resnets.0.conv2.weight",
"decoder.up.3.block.0.norm1.bias": "decoder.up_blocks.0.resnets.0.norm1.bias",
"decoder.up.3.block.0.norm1.weight": "decoder.up_blocks.0.resnets.0.norm1.weight",
"decoder.up.3.block.0.norm2.bias": "decoder.up_blocks.0.resnets.0.norm2.bias",
"decoder.up.3.block.0.norm2.weight": "decoder.up_blocks.0.resnets.0.norm2.weight",
"decoder.up.3.block.1.conv1.bias": "decoder.up_blocks.0.resnets.1.conv1.bias",
"decoder.up.3.block.1.conv1.weight": "decoder.up_blocks.0.resnets.1.conv1.weight",
"decoder.up.3.block.1.conv2.bias": "decoder.up_blocks.0.resnets.1.conv2.bias",
"decoder.up.3.block.1.conv2.weight": "decoder.up_blocks.0.resnets.1.conv2.weight",
"decoder.up.3.block.1.norm1.bias": "decoder.up_blocks.0.resnets.1.norm1.bias",
"decoder.up.3.block.1.norm1.weight": "decoder.up_blocks.0.resnets.1.norm1.weight",
"decoder.up.3.block.1.norm2.bias": "decoder.up_blocks.0.resnets.1.norm2.bias",
"decoder.up.3.block.1.norm2.weight": "decoder.up_blocks.0.resnets.1.norm2.weight",
"decoder.up.3.block.2.conv1.bias": "decoder.up_blocks.0.resnets.2.conv1.bias",
"decoder.up.3.block.2.conv1.weight": "decoder.up_blocks.0.resnets.2.conv1.weight",
"decoder.up.3.block.2.conv2.bias": "decoder.up_blocks.0.resnets.2.conv2.bias",
"decoder.up.3.block.2.conv2.weight": "decoder.up_blocks.0.resnets.2.conv2.weight",
"decoder.up.3.block.2.norm1.bias": "decoder.up_blocks.0.resnets.2.norm1.bias",
"decoder.up.3.block.2.norm1.weight": "decoder.up_blocks.0.resnets.2.norm1.weight",
"decoder.up.3.block.2.norm2.bias": "decoder.up_blocks.0.resnets.2.norm2.bias",
"decoder.up.3.block.2.norm2.weight": "decoder.up_blocks.0.resnets.2.norm2.weight",
"decoder.up.3.upsample.conv.bias": "decoder.up_blocks.0.upsamplers.0.conv.bias",
"decoder.up.3.upsample.conv.weight": "decoder.up_blocks.0.upsamplers.0.conv.weight",
"encoder.conv_in.bias": "encoder.conv_in.bias",
"encoder.conv_in.weight": "encoder.conv_in.weight",
"encoder.conv_out.bias": "encoder.conv_out.bias",
"encoder.conv_out.weight": "encoder.conv_out.weight",
"encoder.down.0.block.0.conv1.bias": "encoder.down_blocks.0.resnets.0.conv1.bias",
"encoder.down.0.block.0.conv1.weight": "encoder.down_blocks.0.resnets.0.conv1.weight",
"encoder.down.0.block.0.conv2.bias": "encoder.down_blocks.0.resnets.0.conv2.bias",
"encoder.down.0.block.0.conv2.weight": "encoder.down_blocks.0.resnets.0.conv2.weight",
"encoder.down.0.block.0.norm1.bias": "encoder.down_blocks.0.resnets.0.norm1.bias",
"encoder.down.0.block.0.norm1.weight": "encoder.down_blocks.0.resnets.0.norm1.weight",
"encoder.down.0.block.0.norm2.bias": "encoder.down_blocks.0.resnets.0.norm2.bias",
"encoder.down.0.block.0.norm2.weight": "encoder.down_blocks.0.resnets.0.norm2.weight",
"encoder.down.0.block.1.conv1.bias": "encoder.down_blocks.0.resnets.1.conv1.bias",
"encoder.down.0.block.1.conv1.weight": "encoder.down_blocks.0.resnets.1.conv1.weight",
"encoder.down.0.block.1.conv2.bias": "encoder.down_blocks.0.resnets.1.conv2.bias",
"encoder.down.0.block.1.conv2.weight": "encoder.down_blocks.0.resnets.1.conv2.weight",
"encoder.down.0.block.1.norm1.bias": "encoder.down_blocks.0.resnets.1.norm1.bias",
"encoder.down.0.block.1.norm1.weight": "encoder.down_blocks.0.resnets.1.norm1.weight",
"encoder.down.0.block.1.norm2.bias": "encoder.down_blocks.0.resnets.1.norm2.bias",
"encoder.down.0.block.1.norm2.weight": "encoder.down_blocks.0.resnets.1.norm2.weight",
"encoder.down.0.downsample.conv.bias": "encoder.down_blocks.0.downsamplers.0.conv.bias",
"encoder.down.0.downsample.conv.weight": "encoder.down_blocks.0.downsamplers.0.conv.weight",
"encoder.down.1.block.0.conv1.bias": "encoder.down_blocks.1.resnets.0.conv1.bias",
"encoder.down.1.block.0.conv1.weight": "encoder.down_blocks.1.resnets.0.conv1.weight",
"encoder.down.1.block.0.conv2.bias": "encoder.down_blocks.1.resnets.0.conv2.bias",
"encoder.down.1.block.0.conv2.weight": "encoder.down_blocks.1.resnets.0.conv2.weight",
"encoder.down.1.block.0.nin_shortcut.bias": "encoder.down_blocks.1.resnets.0.conv_shortcut.bias",
"encoder.down.1.block.0.nin_shortcut.weight": "encoder.down_blocks.1.resnets.0.conv_shortcut.weight",
"encoder.down.1.block.0.norm1.bias": "encoder.down_blocks.1.resnets.0.norm1.bias",
"encoder.down.1.block.0.norm1.weight": "encoder.down_blocks.1.resnets.0.norm1.weight",
"encoder.down.1.block.0.norm2.bias": "encoder.down_blocks.1.resnets.0.norm2.bias",
"encoder.down.1.block.0.norm2.weight": "encoder.down_blocks.1.resnets.0.norm2.weight",
"encoder.down.1.block.1.conv1.bias": "encoder.down_blocks.1.resnets.1.conv1.bias",
"encoder.down.1.block.1.conv1.weight": "encoder.down_blocks.1.resnets.1.conv1.weight",
"encoder.down.1.block.1.conv2.bias": "encoder.down_blocks.1.resnets.1.conv2.bias",
"encoder.down.1.block.1.conv2.weight": "encoder.down_blocks.1.resnets.1.conv2.weight",
"encoder.down.1.block.1.norm1.bias": "encoder.down_blocks.1.resnets.1.norm1.bias",
"encoder.down.1.block.1.norm1.weight": "encoder.down_blocks.1.resnets.1.norm1.weight",
"encoder.down.1.block.1.norm2.bias": "encoder.down_blocks.1.resnets.1.norm2.bias",
"encoder.down.1.block.1.norm2.weight": "encoder.down_blocks.1.resnets.1.norm2.weight",
"encoder.down.1.downsample.conv.bias": "encoder.down_blocks.1.downsamplers.0.conv.bias",
"encoder.down.1.downsample.conv.weight": "encoder.down_blocks.1.downsamplers.0.conv.weight",
"encoder.down.2.block.0.conv1.bias": "encoder.down_blocks.2.resnets.0.conv1.bias",
"encoder.down.2.block.0.conv1.weight": "encoder.down_blocks.2.resnets.0.conv1.weight",
"encoder.down.2.block.0.conv2.bias": "encoder.down_blocks.2.resnets.0.conv2.bias",
"encoder.down.2.block.0.conv2.weight": "encoder.down_blocks.2.resnets.0.conv2.weight",
"encoder.down.2.block.0.nin_shortcut.bias": "encoder.down_blocks.2.resnets.0.conv_shortcut.bias",
"encoder.down.2.block.0.nin_shortcut.weight": "encoder.down_blocks.2.resnets.0.conv_shortcut.weight",
"encoder.down.2.block.0.norm1.bias": "encoder.down_blocks.2.resnets.0.norm1.bias",
"encoder.down.2.block.0.norm1.weight": "encoder.down_blocks.2.resnets.0.norm1.weight",
"encoder.down.2.block.0.norm2.bias": "encoder.down_blocks.2.resnets.0.norm2.bias",
"encoder.down.2.block.0.norm2.weight": "encoder.down_blocks.2.resnets.0.norm2.weight",
"encoder.down.2.block.1.conv1.bias": "encoder.down_blocks.2.resnets.1.conv1.bias",
"encoder.down.2.block.1.conv1.weight": "encoder.down_blocks.2.resnets.1.conv1.weight",
"encoder.down.2.block.1.conv2.bias": "encoder.down_blocks.2.resnets.1.conv2.bias",
"encoder.down.2.block.1.conv2.weight": "encoder.down_blocks.2.resnets.1.conv2.weight",
"encoder.down.2.block.1.norm1.bias": "encoder.down_blocks.2.resnets.1.norm1.bias",
"encoder.down.2.block.1.norm1.weight": "encoder.down_blocks.2.resnets.1.norm1.weight",
"encoder.down.2.block.1.norm2.bias": "encoder.down_blocks.2.resnets.1.norm2.bias",
"encoder.down.2.block.1.norm2.weight": "encoder.down_blocks.2.resnets.1.norm2.weight",
"encoder.down.2.downsample.conv.bias": "encoder.down_blocks.2.downsamplers.0.conv.bias",
"encoder.down.2.downsample.conv.weight": "encoder.down_blocks.2.downsamplers.0.conv.weight",
"encoder.down.3.block.0.conv1.bias": "encoder.down_blocks.3.resnets.0.conv1.bias",
"encoder.down.3.block.0.conv1.weight": "encoder.down_blocks.3.resnets.0.conv1.weight",
"encoder.down.3.block.0.conv2.bias": "encoder.down_blocks.3.resnets.0.conv2.bias",
"encoder.down.3.block.0.conv2.weight": "encoder.down_blocks.3.resnets.0.conv2.weight",
"encoder.down.3.block.0.norm1.bias": "encoder.down_blocks.3.resnets.0.norm1.bias",
"encoder.down.3.block.0.norm1.weight": "encoder.down_blocks.3.resnets.0.norm1.weight",
"encoder.down.3.block.0.norm2.bias": "encoder.down_blocks.3.resnets.0.norm2.bias",
"encoder.down.3.block.0.norm2.weight": "encoder.down_blocks.3.resnets.0.norm2.weight",
"encoder.down.3.block.1.conv1.bias": "encoder.down_blocks.3.resnets.1.conv1.bias",
"encoder.down.3.block.1.conv1.weight": "encoder.down_blocks.3.resnets.1.conv1.weight",
"encoder.down.3.block.1.conv2.bias": "encoder.down_blocks.3.resnets.1.conv2.bias",
"encoder.down.3.block.1.conv2.weight": "encoder.down_blocks.3.resnets.1.conv2.weight",
"encoder.down.3.block.1.norm1.bias": "encoder.down_blocks.3.resnets.1.norm1.bias",
"encoder.down.3.block.1.norm1.weight": "encoder.down_blocks.3.resnets.1.norm1.weight",
"encoder.down.3.block.1.norm2.bias": "encoder.down_blocks.3.resnets.1.norm2.bias",
"encoder.down.3.block.1.norm2.weight": "encoder.down_blocks.3.resnets.1.norm2.weight",
"encoder.mid.attn_1.k.bias": "encoder.mid_block.attentions.0.to_k.bias",
"encoder.mid.attn_1.k.weight": "encoder.mid_block.attentions.0.to_k.weight",
"encoder.mid.attn_1.norm.bias": "encoder.mid_block.attentions.0.group_norm.bias",
"encoder.mid.attn_1.norm.weight": "encoder.mid_block.attentions.0.group_norm.weight",
"encoder.mid.attn_1.proj_out.bias": "encoder.mid_block.attentions.0.to_out.0.bias",
"encoder.mid.attn_1.proj_out.weight": "encoder.mid_block.attentions.0.to_out.0.weight",
"encoder.mid.attn_1.q.bias": "encoder.mid_block.attentions.0.to_q.bias",
"encoder.mid.attn_1.q.weight": "encoder.mid_block.attentions.0.to_q.weight",
"encoder.mid.attn_1.v.bias": "encoder.mid_block.attentions.0.to_v.bias",
"encoder.mid.attn_1.v.weight": "encoder.mid_block.attentions.0.to_v.weight",
"encoder.mid.block_1.conv1.bias": "encoder.mid_block.resnets.0.conv1.bias",
"encoder.mid.block_1.conv1.weight": "encoder.mid_block.resnets.0.conv1.weight",
"encoder.mid.block_1.conv2.bias": "encoder.mid_block.resnets.0.conv2.bias",
"encoder.mid.block_1.conv2.weight": "encoder.mid_block.resnets.0.conv2.weight",
"encoder.mid.block_1.norm1.bias": "encoder.mid_block.resnets.0.norm1.bias",
"encoder.mid.block_1.norm1.weight": "encoder.mid_block.resnets.0.norm1.weight",
"encoder.mid.block_1.norm2.bias": "encoder.mid_block.resnets.0.norm2.bias",
"encoder.mid.block_1.norm2.weight": "encoder.mid_block.resnets.0.norm2.weight",
"encoder.mid.block_2.conv1.bias": "encoder.mid_block.resnets.1.conv1.bias",
"encoder.mid.block_2.conv1.weight": "encoder.mid_block.resnets.1.conv1.weight",
"encoder.mid.block_2.conv2.bias": "encoder.mid_block.resnets.1.conv2.bias",
"encoder.mid.block_2.conv2.weight": "encoder.mid_block.resnets.1.conv2.weight",
"encoder.mid.block_2.norm1.bias": "encoder.mid_block.resnets.1.norm1.bias",
"encoder.mid.block_2.norm1.weight": "encoder.mid_block.resnets.1.norm1.weight",
"encoder.mid.block_2.norm2.bias": "encoder.mid_block.resnets.1.norm2.bias",
"encoder.mid.block_2.norm2.weight": "encoder.mid_block.resnets.1.norm2.weight",
"encoder.norm_out.bias": "encoder.conv_norm_out.bias",
"encoder.norm_out.weight": "encoder.conv_norm_out.weight",
"post_quant_conv.bias": "post_quant_conv.bias",
"post_quant_conv.weight": "post_quant_conv.weight",
"quant_conv.bias": "quant_conv.bias",
"quant_conv.weight": "quant_conv.weight"
}
def get_diffusers_vae_key_from_ldm_key(target_ldm_key, i=None):
for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items():
if i is not None:
ldm_key = ldm_key.replace("{i}", str(i))
diffusers_key = diffusers_key.replace("{i}", str(i))
if ldm_key == target_ldm_key:
return diffusers_key
if ldm_key in vae_ldm_to_diffusers_dict:
return vae_ldm_to_diffusers_dict[ldm_key]
else:
return None
# def get_ldm_vae_key_from_diffusers_key(target_diffusers_key):
# for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items():
# if diffusers_key == target_diffusers_key:
# return ldm_key
# return None
def get_ldm_vae_key_from_diffusers_key(target_diffusers_key):
for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items():
if "{" in diffusers_key: # if we have a placeholder
# escape special characters in the key, and replace the placeholder with a regex group
pattern = re.escape(diffusers_key).replace("\\{i\\}", "(\\d+)")
match = re.match(pattern, target_diffusers_key)
if match: # if we found a match
return ldm_key.format(i=match.group(1))
elif diffusers_key == target_diffusers_key:
return ldm_key
return None
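For illustration (not in the commit), a few lookups grounded in the mapping table above; the regex branch only matters if {i}-style placeholder entries are ever added to the dict:

# assumes both helpers are imported from toolkit.kohya_model_util
assert get_ldm_vae_key_from_diffusers_key(
    'decoder.mid_block.attentions.0.to_q.weight') == 'decoder.mid.attn_1.q.weight'
assert get_diffusers_vae_key_from_ldm_key(
    'decoder.norm_out.weight') == 'decoder.conv_norm_out.weight'
assert get_ldm_vae_key_from_diffusers_key('decoder.unknown.weight') is None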
vae_keys_squished_on_diffusers = [
"decoder.mid_block.attentions.0.to_k.weight",
"decoder.mid_block.attentions.0.to_out.0.weight",
"decoder.mid_block.attentions.0.to_q.weight",
"decoder.mid_block.attentions.0.to_v.weight",
"encoder.mid_block.attentions.0.to_k.weight",
"encoder.mid_block.attentions.0.to_out.0.weight",
"encoder.mid_block.attentions.0.to_q.weight",
"encoder.mid_block.attentions.0.to_v.weight"
]
def convert_diffusers_back_to_ldm(diffusers_vae):
new_state_dict = OrderedDict()
diffusers_state_dict = diffusers_vae.state_dict()
for key, value in diffusers_state_dict.items():
val_to_save = value
if key in vae_keys_squished_on_diffusers:
val_to_save = value.clone()
# (512, 512) diffusers and (512, 512, 1, 1) ldm
val_to_save = val_to_save.unsqueeze(-1).unsqueeze(-1)
ldm_key = get_ldm_vae_key_from_diffusers_key(key)
if ldm_key is not None:
new_state_dict[ldm_key] = val_to_save
else:
# for now add current key
new_state_dict[key] = val_to_save
return new_state_dict
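A hedged usage sketch (not from the commit) of writing a standalone LDM-format VAE with the new converter, mirroring what TrainVAEProcess.save now does; the paths are placeholders and a standard SD 1.x VAE is assumed:

import torch
from safetensors.torch import save_file
from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm

vae = load_vae('/path/to/vae.safetensors', dtype=torch.float32)  # placeholder path
ldm_state_dict = convert_diffusers_back_to_ldm(vae)

# attention projections listed in vae_keys_squished_on_diffusers come back reshaped
# from (512, 512) to the (512, 512, 1, 1) layout ldm checkpoints expect
assert ldm_state_dict['decoder.mid.attn_1.q.weight'].dim() == 4

save_file(ldm_state_dict, '/path/to/vae_ldm_out.safetensors')  # placeholder path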
def convert_ldm_vae_checkpoint(checkpoint, config):
mapping = {}
# extract state dict for VAE
vae_state_dict = {}
vae_key = "first_stage_model."
@@ -390,6 +721,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
new_checkpoint = {}
# for key in list(vae_state_dict.keys()):
# diffusers_key = get_diffusers_vae_key_from_ldm_key(key)
# if diffusers_key is not None:
# new_checkpoint[diffusers_key] = vae_state_dict[key]
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
@@ -411,11 +747,13 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
# Retrieves the keys for the encoder down blocks only
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in
range(num_down_blocks)}
# Retrieves the keys for the decoder up blocks only
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in
range(num_up_blocks)}
for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
@@ -424,9 +762,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.weight"
)
mapping[f"encoder.down.{i}.downsample.conv.weight"] = f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.bias"
)
mapping[f"encoder.down.{i}.downsample.conv.bias"] = f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
@@ -449,15 +789,18 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
resnets = [key for key in up_blocks[block_id] if
f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.weight"
]
mapping[f"decoder.up.{block_id}.upsample.conv.weight"] = f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.bias"
]
mapping[f"decoder.up.{block_id}.upsample.conv.bias"] = f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
@@ -548,7 +891,7 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
text_model_dict = {}
for key in keys:
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
return text_model_dict
@@ -677,36 +1020,36 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i > 0:
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
@@ -715,7 +1058,7 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
sd_mid_res_prefix = f"middle_block.{2 * j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
# buyer beware: this is a *brittle* function,
@@ -772,20 +1115,20 @@ def convert_vae_state_dict(vae_state_dict):
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"up.{3-i}.upsample."
sd_upsample_prefix = f"up.{3 - i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd
for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}."
sd_mid_res_prefix = f"mid.block_{i+1}."
sd_mid_res_prefix = f"mid.block_{i + 1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
vae_conversion_map_attn = [
@@ -850,7 +1193,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
for key in state_dict.keys():
if key.startswith(rep_from):
new_key = rep_to + key[len(rep_from) :]
new_key = rep_to + key[len(rep_from):]
key_reps.append((key, new_key))
for key, new_key in key_reps:
@@ -861,7 +1204,8 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
# TODO: the behavior of the dtype argument is questionable and needs checking; whether the text_encoder can be created in the requested format is unverified
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=False):
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None,
unet_use_linear_projection_in_v2=False):
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
# Convert the UNet2DConditionModel model.
@@ -990,7 +1334,8 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
return new_sd
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None,
vae=None):
if ckpt_path is not None:
# read epoch/step from it; also reload the checkpoint, including the VAE, e.g. when the VAE is not in memory
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
@@ -1059,7 +1404,8 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
return key_count
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None,
use_safetensors=False):
if pretrained_model_name_or_path is None:
# load default settings for v1/v2
if v2:
@@ -1177,4 +1523,4 @@ if __name__ == "__main__":
for ar in aspect_ratios:
if ar in ars:
print("error! duplicate ar:", ar)
ars.add(ar)

toolkit/losses.py (new file, 23 lines)

@@ -0,0 +1,23 @@
import torch
def total_variation(image):
"""
Compute normalized total variation.
Inputs:
- image: PyTorch tensor of shape (N, C, H, W)
Returns:
- TV: total variation normalized by the number of elements
"""
n_elements = image.shape[1] * image.shape[2] * image.shape[3]
return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) +
torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements)
class ComparativeTotalVariation(torch.nn.Module):
"""
Compute the difference in total variation between two images, so the prediction learns to match the target's TV.
"""
def forward(self, pred, target):
return torch.abs(total_variation(pred) - total_variation(target))
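A quick usage sketch (not part of the new file): the loss compares how much total variation each image has, not the pixels themselves, so it pushes reconstructions toward the target's overall level of detail.

import torch
from toolkit.losses import ComparativeTotalVariation

tv_loss_fn = ComparativeTotalVariation()
pred = torch.rand(4, 3, 256, 256)    # reconstructed batch (N, C, H, W)
target = torch.rand(4, 3, 256, 256)  # original batch
loss = tv_loss_fn(pred, target)      # scalar: |TV(pred) - TV(target)|
print(loss.item())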

toolkit/style.py

@@ -127,11 +127,12 @@ class Normalization(nn.Module):
def get_style_model_and_losses(
single_target=False,
device='cuda' if torch.cuda.is_available() else 'cpu'
device='cuda' if torch.cuda.is_available() else 'cpu',
output_layer_name=None,
):
# content_layers = ['conv_4']
# style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
content_layers = ['conv3_2', 'conv4_2']
content_layers = ['conv4_2']
style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
cnn = models.vgg19(pretrained=True).features.to(device).eval()
# normalization module
@@ -150,6 +151,8 @@ def get_style_model_and_losses(
block = 1
children = list(cnn.children())
output_layer = None
for layer in children:
if isinstance(layer, nn.Conv2d):
i += 1
@@ -184,11 +187,16 @@ def get_style_model_and_losses(
model.add_module("style_loss_{}_{}".format(block, i), style_loss)
style_losses.append(style_loss)
if output_layer_name is not None and name == output_layer_name:
output_layer = layer
# now we trim off the layers after the last content and style losses
for i in range(len(model) - 1, -1, -1):
if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
break
if output_layer_name is not None and model[i].name == output_layer_name:
break
model = model[:(i + 1)]
return model, style_losses, content_losses
return model, style_losses, content_losses, output_layer
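A short call-site sketch (not in the diff) matching the four-value return introduced above; output_layer_name is optional and the returned layer handle is None unless a layer with that name is found:

# callers now unpack four values, as TrainVAEProcess.setup_vgg19 does above
model, style_losses, content_losses, output_layer = get_style_model_and_losses(
    single_target=True,
    device='cpu',
    output_layer_name=None,  # e.g. 'conv4_2' to also grab that layer
)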