Handle conversions back to ldm for saving

Jaret Burkett
2023-07-18 19:34:35 -06:00
parent 17c13eef88
commit 8d8229dfc0
5 changed files with 586 additions and 61 deletions
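
In short: training still runs on a diffusers-format AutoencoderKL, but checkpoints are now converted back to the original LDM key layout before they are written. A minimal sketch of the new save path (file paths here are hypothetical, for illustration only):

import torch
from safetensors.torch import save_file
from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm

vae = load_vae('/path/to/vae.safetensors', dtype=torch.float32)  # hypothetical input path
ldm_state_dict = convert_diffusers_back_to_ldm(vae)  # rename diffusers keys back to ldm keys
save_file(ldm_state_dict, '/path/to/vae_ldm.safetensors')  # hypothetical output path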


@@ -1,4 +1,5 @@
 import copy
+import glob
 import os
 import time
 from collections import OrderedDict
@@ -12,8 +13,9 @@ from torch import nn
 from torchvision.transforms import transforms
 from jobs.process import BaseTrainProcess
-from toolkit.kohya_model_util import load_vae
+from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
 from toolkit.data_loader import ImageDataset
+from toolkit.losses import ComparativeTotalVariation
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.style import get_style_model_and_losses
 from toolkit.train_tools import get_torch_dtype
@@ -57,7 +59,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.content_weight = self.get_conf('content_weight', 0)
         self.kld_weight = self.get_conf('kld_weight', 0)
         self.mse_weight = self.get_conf('mse_weight', 1e0)
+        self.tv_weight = self.get_conf('tv_weight', 1e0)
         self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
         self.writer = self.job.writer
@@ -114,7 +116,7 @@ class TrainVAEProcess(BaseTrainProcess):
     def setup_vgg19(self):
         if self.vgg_19 is None:
-            self.vgg_19, self.style_losses, self.content_losses = get_style_model_and_losses(
+            self.vgg_19, self.style_losses, self.content_losses, output = get_style_model_and_losses(
                 single_target=True, device=self.device)
             self.vgg_19.requires_grad_(False)
@@ -149,6 +151,15 @@ class TrainVAEProcess(BaseTrainProcess):
         else:
             return torch.tensor(0.0, device=self.device)

+    def get_tv_loss(self, pred, target):
+        if self.tv_weight > 0:
+            get_tv_loss = ComparativeTotalVariation()
+            loss = get_tv_loss(pred, target)
+            return loss
+        else:
+            return torch.tensor(0.0, device=self.device)
+
     def save(self, step=None):
         if not os.path.exists(self.save_root):
             os.makedirs(self.save_root, exist_ok=True)
@@ -162,7 +173,7 @@ class TrainVAEProcess(BaseTrainProcess):
         # prepare meta
         save_meta = get_meta_for_safetensors(self.meta, self.job.name)

-        state_dict = self.vae.state_dict()
+        state_dict = convert_diffusers_back_to_ldm(self.vae)

         for key in list(state_dict.keys()):
             v = state_dict[key]
@@ -219,6 +230,30 @@ class TrainVAEProcess(BaseTrainProcess):
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png" filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
output_img.save(os.path.join(sample_folder, filename)) output_img.save(os.path.join(sample_folder, filename))
def load_vae(self):
path_to_load = self.vae_path
# see if we have a checkpoint in out output to resume from
self.print(f"Looking for latest checkpoint in {self.save_root}")
files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*.safetensors"))
if files and len(files) > 0:
latest_file = max(files, key=os.path.getmtime)
print(f" - Latest checkpoint is: {latest_file}")
path_to_load = latest_file
# todo update step and epoch count
else:
self.print(f" - No checkpoint found, starting from scratch")
# load vae
self.print(f"Loading VAE")
self.print(f" - Loading VAE: {path_to_load}")
if self.vae is None:
self.vae = load_vae(path_to_load, dtype=self.torch_dtype)
# set decoder to train
self.vae.to(self.device, dtype=self.torch_dtype)
self.vae.requires_grad_(False)
self.vae.eval()
self.vae.decoder.train()
def run(self): def run(self):
super().run() super().run()
self.load_datasets() self.load_datasets()
@@ -241,16 +276,7 @@ class TrainVAEProcess(BaseTrainProcess):
self.print(f" - Max steps: {self.max_steps}") self.print(f" - Max steps: {self.max_steps}")
# load vae # load vae
self.print(f"Loading VAE") self.load_vae()
self.print(f" - Loading VAE: {self.vae_path}")
if self.vae is None:
self.vae = load_vae(self.vae_path, dtype=self.torch_dtype)
# set decoder to train
self.vae.to(self.device, dtype=self.torch_dtype)
self.vae.requires_grad_(False)
self.vae.eval()
self.vae.decoder.train()
         params = []
@@ -260,18 +286,22 @@ class TrainVAEProcess(BaseTrainProcess):
         train_all = 'all' in self.blocks_to_train

-        # mid_block
-        if train_all or 'mid_block' in self.blocks_to_train:
-            params += list(self.vae.decoder.mid_block.parameters())
-            self.vae.decoder.mid_block.requires_grad_(True)
-        # up_blocks
-        if train_all or 'up_blocks' in self.blocks_to_train:
-            params += list(self.vae.decoder.up_blocks.parameters())
-            self.vae.decoder.up_blocks.requires_grad_(True)
-        # conv_out (single conv layer output)
-        if train_all or 'conv_out' in self.blocks_to_train:
-            params += list(self.vae.decoder.conv_out.parameters())
-            self.vae.decoder.conv_out.requires_grad_(True)
+        if train_all:
+            params = list(self.vae.decoder.parameters())
+            self.vae.decoder.requires_grad_(True)
+        else:
+            # mid_block
+            if train_all or 'mid_block' in self.blocks_to_train:
+                params += list(self.vae.decoder.mid_block.parameters())
+                self.vae.decoder.mid_block.requires_grad_(True)
+            # up_blocks
+            if train_all or 'up_blocks' in self.blocks_to_train:
+                params += list(self.vae.decoder.up_blocks.parameters())
+                self.vae.decoder.up_blocks.requires_grad_(True)
+            # conv_out (single conv layer output)
+            if train_all or 'conv_out' in self.blocks_to_train:
+                params += list(self.vae.decoder.conv_out.parameters())
+                self.vae.decoder.conv_out.requires_grad_(True)

         if self.style_weight > 0 or self.content_weight > 0:
             self.setup_vgg19()
@@ -305,7 +335,8 @@ class TrainVAEProcess(BaseTrainProcess):
"style": [], "style": [],
"content": [], "content": [],
"mse": [], "mse": [],
"kl": [] "kl": [],
"tv": [],
}) })
epoch_losses = copy.deepcopy(blank_losses) epoch_losses = copy.deepcopy(blank_losses)
log_losses = copy.deepcopy(blank_losses) log_losses = copy.deepcopy(blank_losses)
@@ -337,8 +368,9 @@ class TrainVAEProcess(BaseTrainProcess):
                 content_loss = self.get_content_loss() * self.content_weight
                 kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
                 mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
+                tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight

-                loss = style_loss + content_loss + kld_loss + mse_loss
+                loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss

                 # Backward pass and optimization
                 optimizer.zero_grad()
@@ -358,6 +390,8 @@ class TrainVAEProcess(BaseTrainProcess):
loss_string += f" kld: {kld_loss.item():.2e}" loss_string += f" kld: {kld_loss.item():.2e}"
if self.mse_weight > 0: if self.mse_weight > 0:
loss_string += f" mse: {mse_loss.item():.2e}" loss_string += f" mse: {mse_loss.item():.2e}"
if self.tv_weight > 0:
loss_string += f" tv: {tv_loss.item():.2e}"
learning_rate = optimizer.param_groups[0]['lr'] learning_rate = optimizer.param_groups[0]['lr']
self.progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}") self.progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}")
@@ -369,12 +403,14 @@ class TrainVAEProcess(BaseTrainProcess):
epoch_losses["content"].append(content_loss.item()) epoch_losses["content"].append(content_loss.item())
epoch_losses["mse"].append(mse_loss.item()) epoch_losses["mse"].append(mse_loss.item())
epoch_losses["kl"].append(kld_loss.item()) epoch_losses["kl"].append(kld_loss.item())
epoch_losses["tv"].append(tv_loss.item())
log_losses["total"].append(loss_value) log_losses["total"].append(loss_value)
log_losses["style"].append(style_loss.item()) log_losses["style"].append(style_loss.item())
log_losses["content"].append(content_loss.item()) log_losses["content"].append(content_loss.item())
log_losses["mse"].append(mse_loss.item()) log_losses["mse"].append(mse_loss.item())
log_losses["kl"].append(kld_loss.item()) log_losses["kl"].append(kld_loss.item())
log_losses["tv"].append(tv_loss.item())
if step != 0: if step != 0:
if self.sample_every and step % self.sample_every == 0: if self.sample_every and step % self.sample_every == 0:

testing/test_vae_cycle.py (new file, +112 lines)

@@ -0,0 +1,112 @@
import os
import torch
from safetensors.torch import load_file
from collections import OrderedDict
from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm, vae_keys_squished_on_diffusers
import json

# this was just used to match the vae keys to the diffusers keys
# you probably won't need this. Unless they change them.... again... again
# on second thought, you probably will

device = torch.device('cpu')
dtype = torch.float32
vae_path = '/mnt/Models/stable-diffusion/models/VAE/vae-ft-mse-840000-ema-pruned/vae-ft-mse-840000-ema-pruned.safetensors'

find_matches = False

state_dict_ldm = load_file(vae_path)
diffusers_vae = load_vae(vae_path, dtype=torch.float32).to(device)

ldm_keys = state_dict_ldm.keys()

matched_keys = {}
duplicated_keys = {}

if find_matches:
    # find values that match with a very low mse
    for ldm_key in ldm_keys:
        ldm_value = state_dict_ldm[ldm_key]
        for diffusers_key in list(diffusers_vae.state_dict().keys()):
            diffusers_value = diffusers_vae.state_dict()[diffusers_key]
            if diffusers_key in vae_keys_squished_on_diffusers:
                diffusers_value = diffusers_value.clone().unsqueeze(-1).unsqueeze(-1)
            # if they are not same shape, skip
            if ldm_value.shape != diffusers_value.shape:
                continue
            mse = torch.nn.functional.mse_loss(ldm_value, diffusers_value)
            if mse < 1e-6:
                if ldm_key in list(matched_keys.keys()):
                    print(f'{ldm_key} already matched to {matched_keys[ldm_key]}')
                    if ldm_key in duplicated_keys:
                        duplicated_keys[ldm_key].append(diffusers_key)
                    else:
                        duplicated_keys[ldm_key] = [diffusers_key]
                    continue
                matched_keys[ldm_key] = diffusers_key
                is_matched = True
                break

    print(f'Found {len(matched_keys)} matches')

dif_to_ldm_state_dict = convert_diffusers_back_to_ldm(diffusers_vae)
dif_to_ldm_state_dict_keys = list(dif_to_ldm_state_dict.keys())

keys_in_both = []

keys_not_in_diffusers = []
for key in ldm_keys:
    if key not in dif_to_ldm_state_dict_keys:
        keys_not_in_diffusers.append(key)

keys_not_in_ldm = []
for key in dif_to_ldm_state_dict_keys:
    if key not in ldm_keys:
        keys_not_in_ldm.append(key)

keys_in_both = []
for key in ldm_keys:
    if key in dif_to_ldm_state_dict_keys:
        keys_in_both.append(key)

# sort them
keys_not_in_diffusers.sort()
keys_not_in_ldm.sort()
keys_in_both.sort()

# print(f'Keys in LDM but not in Diffusers: {len(keys_not_in_diffusers)}{keys_not_in_diffusers}')
# print(f'Keys in Diffusers but not in LDM: {len(keys_not_in_ldm)}{keys_not_in_ldm}')
# print(f'Keys in both: {len(keys_in_both)}{keys_in_both}')

json_data = {
    "both": keys_in_both,
    "ldm": keys_not_in_diffusers,
    "diffusers": keys_not_in_ldm
}
json_data = json.dumps(json_data, indent=4)

remaining_diffusers_values = OrderedDict()
for key in keys_not_in_ldm:
    remaining_diffusers_values[key] = dif_to_ldm_state_dict[key]

# print(remaining_diffusers_values.keys())

remaining_ldm_values = OrderedDict()
for key in keys_not_in_diffusers:
    remaining_ldm_values[key] = state_dict_ldm[key]

# print(json_data)

project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
json_save_path = os.path.join(project_root, 'config', 'keys.json')
json_matched_save_path = os.path.join(project_root, 'config', 'matched.json')
json_duped_save_path = os.path.join(project_root, 'config', 'duped.json')

with open(json_save_path, 'w') as f:
    f.write(json_data)

if find_matches:
    with open(json_matched_save_path, 'w') as f:
        f.write(json.dumps(matched_keys, indent=4))
    with open(json_duped_save_path, 'w') as f:
        f.write(json.dumps(duplicated_keys, indent=4))
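
Condensed, the round-trip invariant this script checks is that every key in the original LDM checkpoint comes back out of the conversion. A minimal version of that check (same imports; the path is hypothetical):

import torch
from safetensors.torch import load_file
from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm

path = '/path/to/vae.safetensors'  # hypothetical path
ldm_sd = load_file(path)
roundtrip_sd = convert_diffusers_back_to_ldm(load_vae(path, dtype=torch.float32))
missing = [k for k in ldm_sd if k not in roundtrip_sd]  # ideally empty
print(f'{len(missing)} ldm keys missing after round trip')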

toolkit/kohya_model_util.py

@@ -1,16 +1,19 @@
 # mostly from https://github.com/kohya-ss/sd-scripts/blob/main/library/model_util.py
 # I am infinitely grateful to @kohya-ss for their amazing work in this field.
 # This version is updated to handle the latest version of the diffusers library.
+import json

 # v1: split from train_db_fixed.py.
 # v2: support safetensors

 import math
 import os
+import re

 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
 from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from safetensors.torch import load_file, save_file
+from collections import OrderedDict

 # model parameters for the DiffUsers version of Stable Diffusion
 NUM_TRAIN_TIMESTEPS = 1000
@@ -151,7 +154,7 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
 def assign_to_checkpoint(
         paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
 ):
     """
     This does the final conversion step: take locally converted weights and apply a global renaming
@@ -228,6 +231,7 @@ def linear_transformer_to_conv(checkpoint):
 def convert_ldm_unet_checkpoint(v2, checkpoint, config):
+    mapping = {}
     """
     Takes a state dict and a config, and returns a converted checkpoint.
     """
@@ -258,33 +262,40 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
     # Retrieves the keys for the input blocks only
     num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
     input_blocks = {
-        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
+        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in
+        range(num_input_blocks)
     }

     # Retrieves the keys for the middle blocks only
     num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
     middle_blocks = {
-        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
+        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in
+        range(num_middle_blocks)
     }

     # Retrieves the keys for the output blocks only
     num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
     output_blocks = {
-        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
+        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in
+        range(num_output_blocks)
     }

     for i in range(1, num_input_blocks):
         block_id = (i - 1) // (config["layers_per_block"] + 1)
         layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

-        resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
+        resnets = [key for key in input_blocks[i] if
+                   f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
         attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

         if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
             new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                 f"input_blocks.{i}.0.op.weight"
             )
-            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
+            mapping[f'input_blocks.{i}.0.op.weight'] = f"down_blocks.{block_id}.downsamplers.0.conv.weight"
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+                f"input_blocks.{i}.0.op.bias")
+            mapping[f'input_blocks.{i}.0.op.bias'] = f"down_blocks.{block_id}.downsamplers.0.conv.bias"

         paths = renew_resnet_paths(resnets)
         meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
@@ -293,7 +304,8 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
         if len(attentions):
             paths = renew_attention_paths(attentions)
             meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
-            assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
+            assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path],
+                                 config=config)

     resnet_0 = middle_blocks[0]
     attentions = middle_blocks[1]
@@ -307,7 +319,8 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
     attentions_paths = renew_attention_paths(attentions)
     meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
-    assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
+    assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path],
+                         config=config)

     for i in range(num_output_blocks):
         block_id = i // (config["layers_per_block"] + 1)
@@ -330,7 +343,8 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
             paths = renew_resnet_paths(resnets)
             meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
-            assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
+            assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path],
+                                 config=config)

     # original:
     # if ["conv.weight", "conv.bias"] in output_block_list.values():
@@ -359,7 +373,8 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
"old": f"output_blocks.{i}.1", "old": f"output_blocks.{i}.1",
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
} }
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path],
config=config)
else: else:
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
for path in resnet_0_paths: for path in resnet_0_paths:
@@ -373,10 +388,326 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
     if v2 and not config.get('use_linear_projection', False):
         linear_transformer_to_conv(new_checkpoint)

+    # print("mapping: ", json.dumps(mapping, indent=4))
     return new_checkpoint
# ldm key: diffusers key
vae_ldm_to_diffusers_dict = {
"decoder.conv_in.bias": "decoder.conv_in.bias",
"decoder.conv_in.weight": "decoder.conv_in.weight",
"decoder.conv_out.bias": "decoder.conv_out.bias",
"decoder.conv_out.weight": "decoder.conv_out.weight",
"decoder.mid.attn_1.k.bias": "decoder.mid_block.attentions.0.to_k.bias",
"decoder.mid.attn_1.k.weight": "decoder.mid_block.attentions.0.to_k.weight",
"decoder.mid.attn_1.norm.bias": "decoder.mid_block.attentions.0.group_norm.bias",
"decoder.mid.attn_1.norm.weight": "decoder.mid_block.attentions.0.group_norm.weight",
"decoder.mid.attn_1.proj_out.bias": "decoder.mid_block.attentions.0.to_out.0.bias",
"decoder.mid.attn_1.proj_out.weight": "decoder.mid_block.attentions.0.to_out.0.weight",
"decoder.mid.attn_1.q.bias": "decoder.mid_block.attentions.0.to_q.bias",
"decoder.mid.attn_1.q.weight": "decoder.mid_block.attentions.0.to_q.weight",
"decoder.mid.attn_1.v.bias": "decoder.mid_block.attentions.0.to_v.bias",
"decoder.mid.attn_1.v.weight": "decoder.mid_block.attentions.0.to_v.weight",
"decoder.mid.block_1.conv1.bias": "decoder.mid_block.resnets.0.conv1.bias",
"decoder.mid.block_1.conv1.weight": "decoder.mid_block.resnets.0.conv1.weight",
"decoder.mid.block_1.conv2.bias": "decoder.mid_block.resnets.0.conv2.bias",
"decoder.mid.block_1.conv2.weight": "decoder.mid_block.resnets.0.conv2.weight",
"decoder.mid.block_1.norm1.bias": "decoder.mid_block.resnets.0.norm1.bias",
"decoder.mid.block_1.norm1.weight": "decoder.mid_block.resnets.0.norm1.weight",
"decoder.mid.block_1.norm2.bias": "decoder.mid_block.resnets.0.norm2.bias",
"decoder.mid.block_1.norm2.weight": "decoder.mid_block.resnets.0.norm2.weight",
"decoder.mid.block_2.conv1.bias": "decoder.mid_block.resnets.1.conv1.bias",
"decoder.mid.block_2.conv1.weight": "decoder.mid_block.resnets.1.conv1.weight",
"decoder.mid.block_2.conv2.bias": "decoder.mid_block.resnets.1.conv2.bias",
"decoder.mid.block_2.conv2.weight": "decoder.mid_block.resnets.1.conv2.weight",
"decoder.mid.block_2.norm1.bias": "decoder.mid_block.resnets.1.norm1.bias",
"decoder.mid.block_2.norm1.weight": "decoder.mid_block.resnets.1.norm1.weight",
"decoder.mid.block_2.norm2.bias": "decoder.mid_block.resnets.1.norm2.bias",
"decoder.mid.block_2.norm2.weight": "decoder.mid_block.resnets.1.norm2.weight",
"decoder.norm_out.bias": "decoder.conv_norm_out.bias",
"decoder.norm_out.weight": "decoder.conv_norm_out.weight",
"decoder.up.0.block.0.conv1.bias": "decoder.up_blocks.3.resnets.0.conv1.bias",
"decoder.up.0.block.0.conv1.weight": "decoder.up_blocks.3.resnets.0.conv1.weight",
"decoder.up.0.block.0.conv2.bias": "decoder.up_blocks.3.resnets.0.conv2.bias",
"decoder.up.0.block.0.conv2.weight": "decoder.up_blocks.3.resnets.0.conv2.weight",
"decoder.up.0.block.0.nin_shortcut.bias": "decoder.up_blocks.3.resnets.0.conv_shortcut.bias",
"decoder.up.0.block.0.nin_shortcut.weight": "decoder.up_blocks.3.resnets.0.conv_shortcut.weight",
"decoder.up.0.block.0.norm1.bias": "decoder.up_blocks.3.resnets.0.norm1.bias",
"decoder.up.0.block.0.norm1.weight": "decoder.up_blocks.3.resnets.0.norm1.weight",
"decoder.up.0.block.0.norm2.bias": "decoder.up_blocks.3.resnets.0.norm2.bias",
"decoder.up.0.block.0.norm2.weight": "decoder.up_blocks.3.resnets.0.norm2.weight",
"decoder.up.0.block.1.conv1.bias": "decoder.up_blocks.3.resnets.1.conv1.bias",
"decoder.up.0.block.1.conv1.weight": "decoder.up_blocks.3.resnets.1.conv1.weight",
"decoder.up.0.block.1.conv2.bias": "decoder.up_blocks.3.resnets.1.conv2.bias",
"decoder.up.0.block.1.conv2.weight": "decoder.up_blocks.3.resnets.1.conv2.weight",
"decoder.up.0.block.1.norm1.bias": "decoder.up_blocks.3.resnets.1.norm1.bias",
"decoder.up.0.block.1.norm1.weight": "decoder.up_blocks.3.resnets.1.norm1.weight",
"decoder.up.0.block.1.norm2.bias": "decoder.up_blocks.3.resnets.1.norm2.bias",
"decoder.up.0.block.1.norm2.weight": "decoder.up_blocks.3.resnets.1.norm2.weight",
"decoder.up.0.block.2.conv1.bias": "decoder.up_blocks.3.resnets.2.conv1.bias",
"decoder.up.0.block.2.conv1.weight": "decoder.up_blocks.3.resnets.2.conv1.weight",
"decoder.up.0.block.2.conv2.bias": "decoder.up_blocks.3.resnets.2.conv2.bias",
"decoder.up.0.block.2.conv2.weight": "decoder.up_blocks.3.resnets.2.conv2.weight",
"decoder.up.0.block.2.norm1.bias": "decoder.up_blocks.3.resnets.2.norm1.bias",
"decoder.up.0.block.2.norm1.weight": "decoder.up_blocks.3.resnets.2.norm1.weight",
"decoder.up.0.block.2.norm2.bias": "decoder.up_blocks.3.resnets.2.norm2.bias",
"decoder.up.0.block.2.norm2.weight": "decoder.up_blocks.3.resnets.2.norm2.weight",
"decoder.up.1.block.0.conv1.bias": "decoder.up_blocks.2.resnets.0.conv1.bias",
"decoder.up.1.block.0.conv1.weight": "decoder.up_blocks.2.resnets.0.conv1.weight",
"decoder.up.1.block.0.conv2.bias": "decoder.up_blocks.2.resnets.0.conv2.bias",
"decoder.up.1.block.0.conv2.weight": "decoder.up_blocks.2.resnets.0.conv2.weight",
"decoder.up.1.block.0.nin_shortcut.bias": "decoder.up_blocks.2.resnets.0.conv_shortcut.bias",
"decoder.up.1.block.0.nin_shortcut.weight": "decoder.up_blocks.2.resnets.0.conv_shortcut.weight",
"decoder.up.1.block.0.norm1.bias": "decoder.up_blocks.2.resnets.0.norm1.bias",
"decoder.up.1.block.0.norm1.weight": "decoder.up_blocks.2.resnets.0.norm1.weight",
"decoder.up.1.block.0.norm2.bias": "decoder.up_blocks.2.resnets.0.norm2.bias",
"decoder.up.1.block.0.norm2.weight": "decoder.up_blocks.2.resnets.0.norm2.weight",
"decoder.up.1.block.1.conv1.bias": "decoder.up_blocks.2.resnets.1.conv1.bias",
"decoder.up.1.block.1.conv1.weight": "decoder.up_blocks.2.resnets.1.conv1.weight",
"decoder.up.1.block.1.conv2.bias": "decoder.up_blocks.2.resnets.1.conv2.bias",
"decoder.up.1.block.1.conv2.weight": "decoder.up_blocks.2.resnets.1.conv2.weight",
"decoder.up.1.block.1.norm1.bias": "decoder.up_blocks.2.resnets.1.norm1.bias",
"decoder.up.1.block.1.norm1.weight": "decoder.up_blocks.2.resnets.1.norm1.weight",
"decoder.up.1.block.1.norm2.bias": "decoder.up_blocks.2.resnets.1.norm2.bias",
"decoder.up.1.block.1.norm2.weight": "decoder.up_blocks.2.resnets.1.norm2.weight",
"decoder.up.1.block.2.conv1.bias": "decoder.up_blocks.2.resnets.2.conv1.bias",
"decoder.up.1.block.2.conv1.weight": "decoder.up_blocks.2.resnets.2.conv1.weight",
"decoder.up.1.block.2.conv2.bias": "decoder.up_blocks.2.resnets.2.conv2.bias",
"decoder.up.1.block.2.conv2.weight": "decoder.up_blocks.2.resnets.2.conv2.weight",
"decoder.up.1.block.2.norm1.bias": "decoder.up_blocks.2.resnets.2.norm1.bias",
"decoder.up.1.block.2.norm1.weight": "decoder.up_blocks.2.resnets.2.norm1.weight",
"decoder.up.1.block.2.norm2.bias": "decoder.up_blocks.2.resnets.2.norm2.bias",
"decoder.up.1.block.2.norm2.weight": "decoder.up_blocks.2.resnets.2.norm2.weight",
"decoder.up.1.upsample.conv.bias": "decoder.up_blocks.2.upsamplers.0.conv.bias",
"decoder.up.1.upsample.conv.weight": "decoder.up_blocks.2.upsamplers.0.conv.weight",
"decoder.up.2.block.0.conv1.bias": "decoder.up_blocks.1.resnets.0.conv1.bias",
"decoder.up.2.block.0.conv1.weight": "decoder.up_blocks.1.resnets.0.conv1.weight",
"decoder.up.2.block.0.conv2.bias": "decoder.up_blocks.1.resnets.0.conv2.bias",
"decoder.up.2.block.0.conv2.weight": "decoder.up_blocks.1.resnets.0.conv2.weight",
"decoder.up.2.block.0.norm1.bias": "decoder.up_blocks.1.resnets.0.norm1.bias",
"decoder.up.2.block.0.norm1.weight": "decoder.up_blocks.1.resnets.0.norm1.weight",
"decoder.up.2.block.0.norm2.bias": "decoder.up_blocks.1.resnets.0.norm2.bias",
"decoder.up.2.block.0.norm2.weight": "decoder.up_blocks.1.resnets.0.norm2.weight",
"decoder.up.2.block.1.conv1.bias": "decoder.up_blocks.1.resnets.1.conv1.bias",
"decoder.up.2.block.1.conv1.weight": "decoder.up_blocks.1.resnets.1.conv1.weight",
"decoder.up.2.block.1.conv2.bias": "decoder.up_blocks.1.resnets.1.conv2.bias",
"decoder.up.2.block.1.conv2.weight": "decoder.up_blocks.1.resnets.1.conv2.weight",
"decoder.up.2.block.1.norm1.bias": "decoder.up_blocks.1.resnets.1.norm1.bias",
"decoder.up.2.block.1.norm1.weight": "decoder.up_blocks.1.resnets.1.norm1.weight",
"decoder.up.2.block.1.norm2.bias": "decoder.up_blocks.1.resnets.1.norm2.bias",
"decoder.up.2.block.1.norm2.weight": "decoder.up_blocks.1.resnets.1.norm2.weight",
"decoder.up.2.block.2.conv1.bias": "decoder.up_blocks.1.resnets.2.conv1.bias",
"decoder.up.2.block.2.conv1.weight": "decoder.up_blocks.1.resnets.2.conv1.weight",
"decoder.up.2.block.2.conv2.bias": "decoder.up_blocks.1.resnets.2.conv2.bias",
"decoder.up.2.block.2.conv2.weight": "decoder.up_blocks.1.resnets.2.conv2.weight",
"decoder.up.2.block.2.norm1.bias": "decoder.up_blocks.1.resnets.2.norm1.bias",
"decoder.up.2.block.2.norm1.weight": "decoder.up_blocks.1.resnets.2.norm1.weight",
"decoder.up.2.block.2.norm2.bias": "decoder.up_blocks.1.resnets.2.norm2.bias",
"decoder.up.2.block.2.norm2.weight": "decoder.up_blocks.1.resnets.2.norm2.weight",
"decoder.up.2.upsample.conv.bias": "decoder.up_blocks.1.upsamplers.0.conv.bias",
"decoder.up.2.upsample.conv.weight": "decoder.up_blocks.1.upsamplers.0.conv.weight",
"decoder.up.3.block.0.conv1.bias": "decoder.up_blocks.0.resnets.0.conv1.bias",
"decoder.up.3.block.0.conv1.weight": "decoder.up_blocks.0.resnets.0.conv1.weight",
"decoder.up.3.block.0.conv2.bias": "decoder.up_blocks.0.resnets.0.conv2.bias",
"decoder.up.3.block.0.conv2.weight": "decoder.up_blocks.0.resnets.0.conv2.weight",
"decoder.up.3.block.0.norm1.bias": "decoder.up_blocks.0.resnets.0.norm1.bias",
"decoder.up.3.block.0.norm1.weight": "decoder.up_blocks.0.resnets.0.norm1.weight",
"decoder.up.3.block.0.norm2.bias": "decoder.up_blocks.0.resnets.0.norm2.bias",
"decoder.up.3.block.0.norm2.weight": "decoder.up_blocks.0.resnets.0.norm2.weight",
"decoder.up.3.block.1.conv1.bias": "decoder.up_blocks.0.resnets.1.conv1.bias",
"decoder.up.3.block.1.conv1.weight": "decoder.up_blocks.0.resnets.1.conv1.weight",
"decoder.up.3.block.1.conv2.bias": "decoder.up_blocks.0.resnets.1.conv2.bias",
"decoder.up.3.block.1.conv2.weight": "decoder.up_blocks.0.resnets.1.conv2.weight",
"decoder.up.3.block.1.norm1.bias": "decoder.up_blocks.0.resnets.1.norm1.bias",
"decoder.up.3.block.1.norm1.weight": "decoder.up_blocks.0.resnets.1.norm1.weight",
"decoder.up.3.block.1.norm2.bias": "decoder.up_blocks.0.resnets.1.norm2.bias",
"decoder.up.3.block.1.norm2.weight": "decoder.up_blocks.0.resnets.1.norm2.weight",
"decoder.up.3.block.2.conv1.bias": "decoder.up_blocks.0.resnets.2.conv1.bias",
"decoder.up.3.block.2.conv1.weight": "decoder.up_blocks.0.resnets.2.conv1.weight",
"decoder.up.3.block.2.conv2.bias": "decoder.up_blocks.0.resnets.2.conv2.bias",
"decoder.up.3.block.2.conv2.weight": "decoder.up_blocks.0.resnets.2.conv2.weight",
"decoder.up.3.block.2.norm1.bias": "decoder.up_blocks.0.resnets.2.norm1.bias",
"decoder.up.3.block.2.norm1.weight": "decoder.up_blocks.0.resnets.2.norm1.weight",
"decoder.up.3.block.2.norm2.bias": "decoder.up_blocks.0.resnets.2.norm2.bias",
"decoder.up.3.block.2.norm2.weight": "decoder.up_blocks.0.resnets.2.norm2.weight",
"decoder.up.3.upsample.conv.bias": "decoder.up_blocks.0.upsamplers.0.conv.bias",
"decoder.up.3.upsample.conv.weight": "decoder.up_blocks.0.upsamplers.0.conv.weight",
"encoder.conv_in.bias": "encoder.conv_in.bias",
"encoder.conv_in.weight": "encoder.conv_in.weight",
"encoder.conv_out.bias": "encoder.conv_out.bias",
"encoder.conv_out.weight": "encoder.conv_out.weight",
"encoder.down.0.block.0.conv1.bias": "encoder.down_blocks.0.resnets.0.conv1.bias",
"encoder.down.0.block.0.conv1.weight": "encoder.down_blocks.0.resnets.0.conv1.weight",
"encoder.down.0.block.0.conv2.bias": "encoder.down_blocks.0.resnets.0.conv2.bias",
"encoder.down.0.block.0.conv2.weight": "encoder.down_blocks.0.resnets.0.conv2.weight",
"encoder.down.0.block.0.norm1.bias": "encoder.down_blocks.0.resnets.0.norm1.bias",
"encoder.down.0.block.0.norm1.weight": "encoder.down_blocks.0.resnets.0.norm1.weight",
"encoder.down.0.block.0.norm2.bias": "encoder.down_blocks.0.resnets.0.norm2.bias",
"encoder.down.0.block.0.norm2.weight": "encoder.down_blocks.0.resnets.0.norm2.weight",
"encoder.down.0.block.1.conv1.bias": "encoder.down_blocks.0.resnets.1.conv1.bias",
"encoder.down.0.block.1.conv1.weight": "encoder.down_blocks.0.resnets.1.conv1.weight",
"encoder.down.0.block.1.conv2.bias": "encoder.down_blocks.0.resnets.1.conv2.bias",
"encoder.down.0.block.1.conv2.weight": "encoder.down_blocks.0.resnets.1.conv2.weight",
"encoder.down.0.block.1.norm1.bias": "encoder.down_blocks.0.resnets.1.norm1.bias",
"encoder.down.0.block.1.norm1.weight": "encoder.down_blocks.0.resnets.1.norm1.weight",
"encoder.down.0.block.1.norm2.bias": "encoder.down_blocks.0.resnets.1.norm2.bias",
"encoder.down.0.block.1.norm2.weight": "encoder.down_blocks.0.resnets.1.norm2.weight",
"encoder.down.0.downsample.conv.bias": "encoder.down_blocks.0.downsamplers.0.conv.bias",
"encoder.down.0.downsample.conv.weight": "encoder.down_blocks.0.downsamplers.0.conv.weight",
"encoder.down.1.block.0.conv1.bias": "encoder.down_blocks.1.resnets.0.conv1.bias",
"encoder.down.1.block.0.conv1.weight": "encoder.down_blocks.1.resnets.0.conv1.weight",
"encoder.down.1.block.0.conv2.bias": "encoder.down_blocks.1.resnets.0.conv2.bias",
"encoder.down.1.block.0.conv2.weight": "encoder.down_blocks.1.resnets.0.conv2.weight",
"encoder.down.1.block.0.nin_shortcut.bias": "encoder.down_blocks.1.resnets.0.conv_shortcut.bias",
"encoder.down.1.block.0.nin_shortcut.weight": "encoder.down_blocks.1.resnets.0.conv_shortcut.weight",
"encoder.down.1.block.0.norm1.bias": "encoder.down_blocks.1.resnets.0.norm1.bias",
"encoder.down.1.block.0.norm1.weight": "encoder.down_blocks.1.resnets.0.norm1.weight",
"encoder.down.1.block.0.norm2.bias": "encoder.down_blocks.1.resnets.0.norm2.bias",
"encoder.down.1.block.0.norm2.weight": "encoder.down_blocks.1.resnets.0.norm2.weight",
"encoder.down.1.block.1.conv1.bias": "encoder.down_blocks.1.resnets.1.conv1.bias",
"encoder.down.1.block.1.conv1.weight": "encoder.down_blocks.1.resnets.1.conv1.weight",
"encoder.down.1.block.1.conv2.bias": "encoder.down_blocks.1.resnets.1.conv2.bias",
"encoder.down.1.block.1.conv2.weight": "encoder.down_blocks.1.resnets.1.conv2.weight",
"encoder.down.1.block.1.norm1.bias": "encoder.down_blocks.1.resnets.1.norm1.bias",
"encoder.down.1.block.1.norm1.weight": "encoder.down_blocks.1.resnets.1.norm1.weight",
"encoder.down.1.block.1.norm2.bias": "encoder.down_blocks.1.resnets.1.norm2.bias",
"encoder.down.1.block.1.norm2.weight": "encoder.down_blocks.1.resnets.1.norm2.weight",
"encoder.down.1.downsample.conv.bias": "encoder.down_blocks.1.downsamplers.0.conv.bias",
"encoder.down.1.downsample.conv.weight": "encoder.down_blocks.1.downsamplers.0.conv.weight",
"encoder.down.2.block.0.conv1.bias": "encoder.down_blocks.2.resnets.0.conv1.bias",
"encoder.down.2.block.0.conv1.weight": "encoder.down_blocks.2.resnets.0.conv1.weight",
"encoder.down.2.block.0.conv2.bias": "encoder.down_blocks.2.resnets.0.conv2.bias",
"encoder.down.2.block.0.conv2.weight": "encoder.down_blocks.2.resnets.0.conv2.weight",
"encoder.down.2.block.0.nin_shortcut.bias": "encoder.down_blocks.2.resnets.0.conv_shortcut.bias",
"encoder.down.2.block.0.nin_shortcut.weight": "encoder.down_blocks.2.resnets.0.conv_shortcut.weight",
"encoder.down.2.block.0.norm1.bias": "encoder.down_blocks.2.resnets.0.norm1.bias",
"encoder.down.2.block.0.norm1.weight": "encoder.down_blocks.2.resnets.0.norm1.weight",
"encoder.down.2.block.0.norm2.bias": "encoder.down_blocks.2.resnets.0.norm2.bias",
"encoder.down.2.block.0.norm2.weight": "encoder.down_blocks.2.resnets.0.norm2.weight",
"encoder.down.2.block.1.conv1.bias": "encoder.down_blocks.2.resnets.1.conv1.bias",
"encoder.down.2.block.1.conv1.weight": "encoder.down_blocks.2.resnets.1.conv1.weight",
"encoder.down.2.block.1.conv2.bias": "encoder.down_blocks.2.resnets.1.conv2.bias",
"encoder.down.2.block.1.conv2.weight": "encoder.down_blocks.2.resnets.1.conv2.weight",
"encoder.down.2.block.1.norm1.bias": "encoder.down_blocks.2.resnets.1.norm1.bias",
"encoder.down.2.block.1.norm1.weight": "encoder.down_blocks.2.resnets.1.norm1.weight",
"encoder.down.2.block.1.norm2.bias": "encoder.down_blocks.2.resnets.1.norm2.bias",
"encoder.down.2.block.1.norm2.weight": "encoder.down_blocks.2.resnets.1.norm2.weight",
"encoder.down.2.downsample.conv.bias": "encoder.down_blocks.2.downsamplers.0.conv.bias",
"encoder.down.2.downsample.conv.weight": "encoder.down_blocks.2.downsamplers.0.conv.weight",
"encoder.down.3.block.0.conv1.bias": "encoder.down_blocks.3.resnets.0.conv1.bias",
"encoder.down.3.block.0.conv1.weight": "encoder.down_blocks.3.resnets.0.conv1.weight",
"encoder.down.3.block.0.conv2.bias": "encoder.down_blocks.3.resnets.0.conv2.bias",
"encoder.down.3.block.0.conv2.weight": "encoder.down_blocks.3.resnets.0.conv2.weight",
"encoder.down.3.block.0.norm1.bias": "encoder.down_blocks.3.resnets.0.norm1.bias",
"encoder.down.3.block.0.norm1.weight": "encoder.down_blocks.3.resnets.0.norm1.weight",
"encoder.down.3.block.0.norm2.bias": "encoder.down_blocks.3.resnets.0.norm2.bias",
"encoder.down.3.block.0.norm2.weight": "encoder.down_blocks.3.resnets.0.norm2.weight",
"encoder.down.3.block.1.conv1.bias": "encoder.down_blocks.3.resnets.1.conv1.bias",
"encoder.down.3.block.1.conv1.weight": "encoder.down_blocks.3.resnets.1.conv1.weight",
"encoder.down.3.block.1.conv2.bias": "encoder.down_blocks.3.resnets.1.conv2.bias",
"encoder.down.3.block.1.conv2.weight": "encoder.down_blocks.3.resnets.1.conv2.weight",
"encoder.down.3.block.1.norm1.bias": "encoder.down_blocks.3.resnets.1.norm1.bias",
"encoder.down.3.block.1.norm1.weight": "encoder.down_blocks.3.resnets.1.norm1.weight",
"encoder.down.3.block.1.norm2.bias": "encoder.down_blocks.3.resnets.1.norm2.bias",
"encoder.down.3.block.1.norm2.weight": "encoder.down_blocks.3.resnets.1.norm2.weight",
"encoder.mid.attn_1.k.bias": "encoder.mid_block.attentions.0.to_k.bias",
"encoder.mid.attn_1.k.weight": "encoder.mid_block.attentions.0.to_k.weight",
"encoder.mid.attn_1.norm.bias": "encoder.mid_block.attentions.0.group_norm.bias",
"encoder.mid.attn_1.norm.weight": "encoder.mid_block.attentions.0.group_norm.weight",
"encoder.mid.attn_1.proj_out.bias": "encoder.mid_block.attentions.0.to_out.0.bias",
"encoder.mid.attn_1.proj_out.weight": "encoder.mid_block.attentions.0.to_out.0.weight",
"encoder.mid.attn_1.q.bias": "encoder.mid_block.attentions.0.to_q.bias",
"encoder.mid.attn_1.q.weight": "encoder.mid_block.attentions.0.to_q.weight",
"encoder.mid.attn_1.v.bias": "encoder.mid_block.attentions.0.to_v.bias",
"encoder.mid.attn_1.v.weight": "encoder.mid_block.attentions.0.to_v.weight",
"encoder.mid.block_1.conv1.bias": "encoder.mid_block.resnets.0.conv1.bias",
"encoder.mid.block_1.conv1.weight": "encoder.mid_block.resnets.0.conv1.weight",
"encoder.mid.block_1.conv2.bias": "encoder.mid_block.resnets.0.conv2.bias",
"encoder.mid.block_1.conv2.weight": "encoder.mid_block.resnets.0.conv2.weight",
"encoder.mid.block_1.norm1.bias": "encoder.mid_block.resnets.0.norm1.bias",
"encoder.mid.block_1.norm1.weight": "encoder.mid_block.resnets.0.norm1.weight",
"encoder.mid.block_1.norm2.bias": "encoder.mid_block.resnets.0.norm2.bias",
"encoder.mid.block_1.norm2.weight": "encoder.mid_block.resnets.0.norm2.weight",
"encoder.mid.block_2.conv1.bias": "encoder.mid_block.resnets.1.conv1.bias",
"encoder.mid.block_2.conv1.weight": "encoder.mid_block.resnets.1.conv1.weight",
"encoder.mid.block_2.conv2.bias": "encoder.mid_block.resnets.1.conv2.bias",
"encoder.mid.block_2.conv2.weight": "encoder.mid_block.resnets.1.conv2.weight",
"encoder.mid.block_2.norm1.bias": "encoder.mid_block.resnets.1.norm1.bias",
"encoder.mid.block_2.norm1.weight": "encoder.mid_block.resnets.1.norm1.weight",
"encoder.mid.block_2.norm2.bias": "encoder.mid_block.resnets.1.norm2.bias",
"encoder.mid.block_2.norm2.weight": "encoder.mid_block.resnets.1.norm2.weight",
"encoder.norm_out.bias": "encoder.conv_norm_out.bias",
"encoder.norm_out.weight": "encoder.conv_norm_out.weight",
"post_quant_conv.bias": "post_quant_conv.bias",
"post_quant_conv.weight": "post_quant_conv.weight",
"quant_conv.bias": "quant_conv.bias",
"quant_conv.weight": "quant_conv.weight"
}
def get_diffusers_vae_key_from_ldm_key(target_ldm_key, i=None):
    for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items():
        if i is not None:
            ldm_key = ldm_key.replace("{i}", str(i))
            diffusers_key = diffusers_key.replace("{i}", str(i))
        if ldm_key == target_ldm_key:
            return diffusers_key
    if ldm_key in vae_ldm_to_diffusers_dict:
        return vae_ldm_to_diffusers_dict[ldm_key]
    else:
        return None
# def get_ldm_vae_key_from_diffusers_key(target_diffusers_key):
#     for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items():
#         if diffusers_key == target_diffusers_key:
#             return ldm_key
#     return None
def get_ldm_vae_key_from_diffusers_key(target_diffusers_key):
    for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items():
        if "{" in diffusers_key:  # if we have a placeholder
            # escape special characters in the key, and replace the placeholder with a regex group
            pattern = re.escape(diffusers_key).replace("\\{i\\}", "(\\d+)")
            match = re.match(pattern, target_diffusers_key)
            if match:  # if we found a match
                return ldm_key.format(i=match.group(1))
        elif diffusers_key == target_diffusers_key:
            return ldm_key
    return None
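
Given the mapping table above, the reverse lookup behaves like this (illustrative calls; the second returns None because the key is not in the table):

get_ldm_vae_key_from_diffusers_key("decoder.conv_norm_out.bias")  # -> "decoder.norm_out.bias"
get_ldm_vae_key_from_diffusers_key("not.a.real.key")  # -> None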
vae_keys_squished_on_diffusers = [
    "decoder.mid_block.attentions.0.to_k.weight",
    "decoder.mid_block.attentions.0.to_out.0.weight",
    "decoder.mid_block.attentions.0.to_q.weight",
    "decoder.mid_block.attentions.0.to_v.weight",
    "encoder.mid_block.attentions.0.to_k.weight",
    "encoder.mid_block.attentions.0.to_out.0.weight",
    "encoder.mid_block.attentions.0.to_q.weight",
    "encoder.mid_block.attentions.0.to_v.weight"
]
def convert_diffusers_back_to_ldm(diffusers_vae):
    new_state_dict = OrderedDict()
    diffusers_state_dict = diffusers_vae.state_dict()
    for key, value in diffusers_state_dict.items():
        val_to_save = value
        if key in vae_keys_squished_on_diffusers:
            val_to_save = value.clone()
            # (512, 512) diffusers and (512, 512, 1, 1) ldm
            val_to_save = val_to_save.unsqueeze(-1).unsqueeze(-1)
        ldm_key = get_ldm_vae_key_from_diffusers_key(key)
        if ldm_key is not None:
            new_state_dict[ldm_key] = val_to_save
        else:
            # for now add current key
            new_state_dict[key] = val_to_save
    return new_state_dict
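
The unsqueeze above handles the one shape mismatch in the table: diffusers stores the mid-block attention projections as plain linear weights, while the LDM checkpoint stores them as 1x1 convolutions. Roughly:

import torch

w_diffusers = torch.rand(512, 512)  # attention projection as a linear weight (diffusers)
w_ldm = w_diffusers.unsqueeze(-1).unsqueeze(-1)  # same values viewed as a 1x1 conv kernel (ldm)
print(w_ldm.shape)  # torch.Size([512, 512, 1, 1])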
 def convert_ldm_vae_checkpoint(checkpoint, config):
+    mapping = {}
     # extract state dict for VAE
     vae_state_dict = {}
     vae_key = "first_stage_model."
@@ -390,6 +721,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
     new_checkpoint = {}

+    # for key in list(vae_state_dict.keys()):
+    #     diffusers_key = get_diffusers_vae_key_from_ldm_key(key)
+    #     if diffusers_key is not None:
+    #         new_checkpoint[diffusers_key] = vae_state_dict[key]

     new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
     new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
     new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
@@ -411,11 +747,13 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
     # Retrieves the keys for the encoder down blocks only
     num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
-    down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
+    down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in
+                   range(num_down_blocks)}

     # Retrieves the keys for the decoder up blocks only
     num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
-    up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
+    up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in
+                 range(num_up_blocks)}

     for i in range(num_down_blocks):
         resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
@@ -424,9 +762,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.weight" f"encoder.down.{i}.downsample.conv.weight"
) )
mapping[f"encoder.down.{i}.downsample.conv.weight"] = f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.bias" f"encoder.down.{i}.downsample.conv.bias"
) )
mapping[f"encoder.down.{i}.downsample.conv.bias"] = f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"
paths = renew_vae_resnet_paths(resnets) paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
@@ -449,15 +789,18 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
     for i in range(num_up_blocks):
         block_id = num_up_blocks - 1 - i
-        resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
+        resnets = [key for key in up_blocks[block_id] if
+                   f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]

         if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
             new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                 f"decoder.up.{block_id}.upsample.conv.weight"
             ]
+            mapping[f"decoder.up.{block_id}.upsample.conv.weight"] = f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"
             new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                 f"decoder.up.{block_id}.upsample.conv.bias"
             ]
+            mapping[f"decoder.up.{block_id}.upsample.conv.bias"] = f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"

         paths = renew_vae_resnet_paths(resnets)
         meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
@@ -548,7 +891,7 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
     text_model_dict = {}
     for key in keys:
         if key.startswith("cond_stage_model.transformer"):
-            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
+            text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]

     return text_model_dict
@@ -677,36 +1020,36 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
         for j in range(2):
             # loop over resnets/attentions for downblocks
             hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
-            sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+            sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
             unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

             if i < 3:
                 # no attention layers in down_blocks.3
                 hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
-                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+                sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
                 unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

         for j in range(3):
             # loop over resnets/attentions for upblocks
             hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
-            sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+            sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
             unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

             if i > 0:
                 # no attention layers in up_blocks.0
                 hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
-                sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+                sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
                 unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

         if i < 3:
             # no downsample in down_blocks.3
             hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
-            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+            sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
             unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

             # no upsample in up_blocks.3
             hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
-            sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+            sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
             unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

     hf_mid_atn_prefix = "mid_block.attentions.0."
@@ -715,7 +1058,7 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
     for j in range(2):
         hf_mid_res_prefix = f"mid_block.resnets.{j}."
-        sd_mid_res_prefix = f"middle_block.{2*j}."
+        sd_mid_res_prefix = f"middle_block.{2 * j}."
         unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

     # buyer beware: this is a *brittle* function,
@@ -772,20 +1115,20 @@ def convert_vae_state_dict(vae_state_dict):
             vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

             hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
-            sd_upsample_prefix = f"up.{3-i}.upsample."
+            sd_upsample_prefix = f"up.{3 - i}.upsample."
             vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

         # up_blocks have three resnets
         # also, up blocks in hf are numbered in reverse from sd
         for j in range(3):
             hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
-            sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+            sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
             vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

     # this part accounts for mid blocks in both the encoder and the decoder
     for i in range(2):
         hf_mid_res_prefix = f"mid_block.resnets.{i}."
-        sd_mid_res_prefix = f"mid.block_{i+1}."
+        sd_mid_res_prefix = f"mid.block_{i + 1}."
         vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))

     vae_conversion_map_attn = [
@@ -850,7 +1193,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
     for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
         for key in state_dict.keys():
             if key.startswith(rep_from):
-                new_key = rep_to + key[len(rep_from) :]
+                new_key = rep_to + key[len(rep_from):]
                 key_reps.append((key, new_key))

     for key, new_key in key_reps:
@@ -861,7 +1204,8 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
 # TODO the behavior of the dtype option is questionable and should be checked; creating the text_encoder in the specified format is unverified
-def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=False):
+def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None,
+                                                 unet_use_linear_projection_in_v2=False):
     _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)

     # Convert the UNet2DConditionModel model.
@@ -990,7 +1334,8 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
     return new_sd

-def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
+def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None,
+                                     vae=None):
     if ckpt_path is not None:
         # reads the epoch/step, and reloads the checkpoint including the VAE, e.g. when the VAE is not in memory
         checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
@@ -1059,7 +1404,8 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
     return key_count

-def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
+def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None,
+                              use_safetensors=False):
     if pretrained_model_name_or_path is None:
         # load default settings for v1/v2
         if v2:
@@ -1177,4 +1523,4 @@ if __name__ == "__main__":
     for ar in aspect_ratios:
         if ar in ars:
             print("error! duplicate ar:", ar)
         ars.add(ar)

toolkit/losses.py (new file, +23 lines)

@@ -0,0 +1,23 @@
import torch


def total_variation(image):
    """
    Compute normalized total variation.
    Inputs:
    - image: PyTorch Variable of shape (N, C, H, W)
    Returns:
    - TV: total variation normalized by the number of elements
    """
    n_elements = image.shape[1] * image.shape[2] * image.shape[3]
    return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) +
             torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements)


class ComparativeTotalVariation(torch.nn.Module):
    """
    Compute the comparative loss in TV between two images, to match their TV.
    """

    def forward(self, pred, target):
        return torch.abs(total_variation(pred) - total_variation(target))
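
A quick sanity check of the new loss, with random tensors standing in for the decoded and target images:

import torch
from toolkit.losses import ComparativeTotalVariation, total_variation

pred = torch.rand(1, 3, 64, 64)
target = torch.rand(1, 3, 64, 64)
loss = ComparativeTotalVariation()(pred, target)  # |TV(pred) - TV(target)|, a scalar tensor
print(loss.item(), total_variation(pred).item())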

toolkit/style.py

@@ -127,11 +127,12 @@ class Normalization(nn.Module):
 def get_style_model_and_losses(
         single_target=False,
-        device='cuda' if torch.cuda.is_available() else 'cpu'
+        device='cuda' if torch.cuda.is_available() else 'cpu',
+        output_layer_name=None,
 ):
     # content_layers = ['conv_4']
     # style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
-    content_layers = ['conv3_2', 'conv4_2']
+    content_layers = ['conv4_2']
     style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
     cnn = models.vgg19(pretrained=True).features.to(device).eval()

     # normalization module
@@ -150,6 +151,8 @@ def get_style_model_and_losses(
     block = 1
     children = list(cnn.children())

+    output_layer = None
+
     for layer in children:
         if isinstance(layer, nn.Conv2d):
             i += 1
@@ -184,11 +187,16 @@ def get_style_model_and_losses(
model.add_module("style_loss_{}_{}".format(block, i), style_loss) model.add_module("style_loss_{}_{}".format(block, i), style_loss)
style_losses.append(style_loss) style_losses.append(style_loss)
if output_layer_name is not None and name == output_layer_name:
output_layer = layer
# now we trim off the layers after the last content and style losses # now we trim off the layers after the last content and style losses
for i in range(len(model) - 1, -1, -1): for i in range(len(model) - 1, -1, -1):
if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss): if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
break break
if output_layer_name is not None and model[i].name == output_layer_name:
break
model = model[:(i + 1)] model = model[:(i + 1)]
return model, style_losses, content_losses return model, style_losses, content_losses, output_layer
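
Note the function now returns a 4-tuple, so every caller has to unpack output_layer (setup_vgg19 above already does); output_layer stays None unless output_layer_name names a layer:

from toolkit.style import get_style_model_and_losses

model, style_losses, content_losses, output_layer = get_style_model_and_losses(
    single_target=True,
    device='cpu',
)
# output_layer is None here because no output_layer_name was passed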