From 8d8229dfc0b27936b1bcf5ac4ddaa73f74252441 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 18 Jul 2023 19:34:35 -0600 Subject: [PATCH] Handle conversions back to ldm for saving --- jobs/process/TrainVAEProcess.py | 92 +++++--- testing/test_vae_cycle.py | 112 +++++++++ toolkit/kohya_model_util.py | 406 +++++++++++++++++++++++++++++--- toolkit/losses.py | 23 ++ toolkit/style.py | 14 +- 5 files changed, 586 insertions(+), 61 deletions(-) create mode 100644 testing/test_vae_cycle.py create mode 100644 toolkit/losses.py diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py index f5689143..f000b077 100644 --- a/jobs/process/TrainVAEProcess.py +++ b/jobs/process/TrainVAEProcess.py @@ -1,4 +1,5 @@ import copy +import glob import os import time from collections import OrderedDict @@ -12,8 +13,9 @@ from torch import nn from torchvision.transforms import transforms from jobs.process import BaseTrainProcess -from toolkit.kohya_model_util import load_vae +from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm from toolkit.data_loader import ImageDataset +from toolkit.losses import ComparativeTotalVariation from toolkit.metadata import get_meta_for_safetensors from toolkit.style import get_style_model_and_losses from toolkit.train_tools import get_torch_dtype @@ -57,7 +59,7 @@ class TrainVAEProcess(BaseTrainProcess): self.content_weight = self.get_conf('content_weight', 0) self.kld_weight = self.get_conf('kld_weight', 0) self.mse_weight = self.get_conf('mse_weight', 1e0) - + self.tv_weight = self.get_conf('tv_weight', 1e0) self.blocks_to_train = self.get_conf('blocks_to_train', ['all']) self.writer = self.job.writer @@ -114,7 +116,7 @@ class TrainVAEProcess(BaseTrainProcess): def setup_vgg19(self): if self.vgg_19 is None: - self.vgg_19, self.style_losses, self.content_losses = get_style_model_and_losses( + self.vgg_19, self.style_losses, self.content_losses, output = get_style_model_and_losses( single_target=True, device=self.device) self.vgg_19.requires_grad_(False) @@ -149,6 +151,15 @@ class TrainVAEProcess(BaseTrainProcess): else: return torch.tensor(0.0, device=self.device) + def get_tv_loss(self, pred, target): + if self.tv_weight > 0: + get_tv_loss = ComparativeTotalVariation() + loss = get_tv_loss(pred, target) + return loss + else: + return torch.tensor(0.0, device=self.device) + + def save(self, step=None): if not os.path.exists(self.save_root): os.makedirs(self.save_root, exist_ok=True) @@ -162,7 +173,7 @@ class TrainVAEProcess(BaseTrainProcess): # prepare meta save_meta = get_meta_for_safetensors(self.meta, self.job.name) - state_dict = self.vae.state_dict() + state_dict = convert_diffusers_back_to_ldm(self.vae) for key in list(state_dict.keys()): v = state_dict[key] @@ -219,6 +230,30 @@ class TrainVAEProcess(BaseTrainProcess): filename = f"{seconds_since_epoch}{step_num}_{i_str}.png" output_img.save(os.path.join(sample_folder, filename)) + def load_vae(self): + path_to_load = self.vae_path + # see if we have a checkpoint in out output to resume from + self.print(f"Looking for latest checkpoint in {self.save_root}") + files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*.safetensors")) + if files and len(files) > 0: + latest_file = max(files, key=os.path.getmtime) + print(f" - Latest checkpoint is: {latest_file}") + path_to_load = latest_file + # todo update step and epoch count + else: + self.print(f" - No checkpoint found, starting from scratch") + # load vae + self.print(f"Loading VAE") + self.print(f" - Loading VAE: 
{path_to_load}") + if self.vae is None: + self.vae = load_vae(path_to_load, dtype=self.torch_dtype) + + # set decoder to train + self.vae.to(self.device, dtype=self.torch_dtype) + self.vae.requires_grad_(False) + self.vae.eval() + self.vae.decoder.train() + def run(self): super().run() self.load_datasets() @@ -241,16 +276,7 @@ class TrainVAEProcess(BaseTrainProcess): self.print(f" - Max steps: {self.max_steps}") # load vae - self.print(f"Loading VAE") - self.print(f" - Loading VAE: {self.vae_path}") - if self.vae is None: - self.vae = load_vae(self.vae_path, dtype=self.torch_dtype) - - # set decoder to train - self.vae.to(self.device, dtype=self.torch_dtype) - self.vae.requires_grad_(False) - self.vae.eval() - self.vae.decoder.train() + self.load_vae() params = [] @@ -260,18 +286,22 @@ class TrainVAEProcess(BaseTrainProcess): train_all = 'all' in self.blocks_to_train - # mid_block - if train_all or 'mid_block' in self.blocks_to_train: - params += list(self.vae.decoder.mid_block.parameters()) - self.vae.decoder.mid_block.requires_grad_(True) - # up_blocks - if train_all or 'up_blocks' in self.blocks_to_train: - params += list(self.vae.decoder.up_blocks.parameters()) - self.vae.decoder.up_blocks.requires_grad_(True) - # conv_out (single conv layer output) - if train_all or 'conv_out' in self.blocks_to_train: - params += list(self.vae.decoder.conv_out.parameters()) - self.vae.decoder.conv_out.requires_grad_(True) + if train_all: + params = list(self.vae.decoder.parameters()) + self.vae.decoder.requires_grad_(True) + else: + # mid_block + if train_all or 'mid_block' in self.blocks_to_train: + params += list(self.vae.decoder.mid_block.parameters()) + self.vae.decoder.mid_block.requires_grad_(True) + # up_blocks + if train_all or 'up_blocks' in self.blocks_to_train: + params += list(self.vae.decoder.up_blocks.parameters()) + self.vae.decoder.up_blocks.requires_grad_(True) + # conv_out (single conv layer output) + if train_all or 'conv_out' in self.blocks_to_train: + params += list(self.vae.decoder.conv_out.parameters()) + self.vae.decoder.conv_out.requires_grad_(True) if self.style_weight > 0 or self.content_weight > 0: self.setup_vgg19() @@ -305,7 +335,8 @@ class TrainVAEProcess(BaseTrainProcess): "style": [], "content": [], "mse": [], - "kl": [] + "kl": [], + "tv": [], }) epoch_losses = copy.deepcopy(blank_losses) log_losses = copy.deepcopy(blank_losses) @@ -337,8 +368,9 @@ class TrainVAEProcess(BaseTrainProcess): content_loss = self.get_content_loss() * self.content_weight kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight + tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight - loss = style_loss + content_loss + kld_loss + mse_loss + loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss # Backward pass and optimization optimizer.zero_grad() @@ -358,6 +390,8 @@ class TrainVAEProcess(BaseTrainProcess): loss_string += f" kld: {kld_loss.item():.2e}" if self.mse_weight > 0: loss_string += f" mse: {mse_loss.item():.2e}" + if self.tv_weight > 0: + loss_string += f" tv: {tv_loss.item():.2e}" learning_rate = optimizer.param_groups[0]['lr'] self.progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}") @@ -369,12 +403,14 @@ class TrainVAEProcess(BaseTrainProcess): epoch_losses["content"].append(content_loss.item()) epoch_losses["mse"].append(mse_loss.item()) epoch_losses["kl"].append(kld_loss.item()) + epoch_losses["tv"].append(tv_loss.item()) log_losses["total"].append(loss_value) 
log_losses["style"].append(style_loss.item()) log_losses["content"].append(content_loss.item()) log_losses["mse"].append(mse_loss.item()) log_losses["kl"].append(kld_loss.item()) + log_losses["tv"].append(tv_loss.item()) if step != 0: if self.sample_every and step % self.sample_every == 0: diff --git a/testing/test_vae_cycle.py b/testing/test_vae_cycle.py new file mode 100644 index 00000000..175e8f8f --- /dev/null +++ b/testing/test_vae_cycle.py @@ -0,0 +1,112 @@ +import os + +import torch +from safetensors.torch import load_file +from collections import OrderedDict +from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm, vae_keys_squished_on_diffusers +import json +# this was just used to match the vae keys to the diffusers keys +# you probably wont need this. Unless they change them.... again... again +# on second thought, you probably will + +device = torch.device('cpu') +dtype = torch.float32 +vae_path = '/mnt/Models/stable-diffusion/models/VAE/vae-ft-mse-840000-ema-pruned/vae-ft-mse-840000-ema-pruned.safetensors' + +find_matches = False + +state_dict_ldm = load_file(vae_path) +diffusers_vae = load_vae(vae_path, dtype=torch.float32).to(device) + +ldm_keys = state_dict_ldm.keys() + +matched_keys = {} +duplicated_keys = { + +} + +if find_matches: + # find values that match with a very low mse + for ldm_key in ldm_keys: + ldm_value = state_dict_ldm[ldm_key] + for diffusers_key in list(diffusers_vae.state_dict().keys()): + diffusers_value = diffusers_vae.state_dict()[diffusers_key] + if diffusers_key in vae_keys_squished_on_diffusers: + diffusers_value = diffusers_value.clone().unsqueeze(-1).unsqueeze(-1) + # if they are not same shape, skip + if ldm_value.shape != diffusers_value.shape: + continue + mse = torch.nn.functional.mse_loss(ldm_value, diffusers_value) + if mse < 1e-6: + if ldm_key in list(matched_keys.keys()): + print(f'{ldm_key} already matched to {matched_keys[ldm_key]}') + if ldm_key in duplicated_keys: + duplicated_keys[ldm_key].append(diffusers_key) + else: + duplicated_keys[ldm_key] = [diffusers_key] + continue + matched_keys[ldm_key] = diffusers_key + is_matched = True + break + + print(f'Found {len(matched_keys)} matches') + +dif_to_ldm_state_dict = convert_diffusers_back_to_ldm(diffusers_vae) +dif_to_ldm_state_dict_keys = list(dif_to_ldm_state_dict.keys()) +keys_in_both = [] + +keys_not_in_diffusers = [] +for key in ldm_keys: + if key not in dif_to_ldm_state_dict_keys: + keys_not_in_diffusers.append(key) + +keys_not_in_ldm = [] +for key in dif_to_ldm_state_dict_keys: + if key not in ldm_keys: + keys_not_in_ldm.append(key) + +keys_in_both = [] +for key in ldm_keys: + if key in dif_to_ldm_state_dict_keys: + keys_in_both.append(key) + +# sort them +keys_not_in_diffusers.sort() +keys_not_in_ldm.sort() +keys_in_both.sort() + +# print(f'Keys in LDM but not in Diffusers: {len(keys_not_in_diffusers)}{keys_not_in_diffusers}') +# print(f'Keys in Diffusers but not in LDM: {len(keys_not_in_ldm)}{keys_not_in_ldm}') +# print(f'Keys in both: {len(keys_in_both)}{keys_in_both}') + +json_data = { + "both": keys_in_both, + "ldm": keys_not_in_diffusers, + "diffusers": keys_not_in_ldm +} +json_data = json.dumps(json_data, indent=4) + +remaining_diffusers_values = OrderedDict() +for key in keys_not_in_ldm: + remaining_diffusers_values[key] = dif_to_ldm_state_dict[key] + +# print(remaining_diffusers_values.keys()) + +remaining_ldm_values = OrderedDict() +for key in keys_not_in_diffusers: + remaining_ldm_values[key] = state_dict_ldm[key] + +# print(json_data) + 
+project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +json_save_path = os.path.join(project_root, 'config', 'keys.json') +json_matched_save_path = os.path.join(project_root, 'config', 'matched.json') +json_duped_save_path = os.path.join(project_root, 'config', 'duped.json') + +with open(json_save_path, 'w') as f: + f.write(json_data) +if find_matches: + with open(json_matched_save_path, 'w') as f: + f.write(json.dumps(matched_keys, indent=4)) + with open(json_duped_save_path, 'w') as f: + f.write(json.dumps(duplicated_keys, indent=4)) diff --git a/toolkit/kohya_model_util.py b/toolkit/kohya_model_util.py index 8aba172f..5e976f0b 100644 --- a/toolkit/kohya_model_util.py +++ b/toolkit/kohya_model_util.py @@ -1,16 +1,19 @@ # mostly from https://github.com/kohya-ss/sd-scripts/blob/main/library/model_util.py # I am infinitely grateful to @kohya-ss for their amazing work in this field. # This version is updated to handle the latest version of the diffusers library. - +import json # v1: split from train_db_fixed.py. # v2: support safetensors import math import os +import re + import torch from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel from safetensors.torch import load_file, save_file +from collections import OrderedDict # DiffUsers版StableDiffusionのモデルパラメータ NUM_TRAIN_TIMESTEPS = 1000 @@ -151,7 +154,7 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): def assign_to_checkpoint( - paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None ): """ This does the final conversion step: take locally converted weights and apply a global renaming @@ -228,6 +231,7 @@ def linear_transformer_to_conv(checkpoint): def convert_ldm_unet_checkpoint(v2, checkpoint, config): + mapping = {} """ Takes a state dict and a config, and returns a converted checkpoint. """ @@ -258,33 +262,40 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config): # Retrieves the keys for the input blocks only num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) input_blocks = { - layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks) + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in + range(num_input_blocks) } # Retrieves the keys for the middle blocks only num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) middle_blocks = { - layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks) + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in + range(num_middle_blocks) } # Retrieves the keys for the output blocks only num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) output_blocks = { - layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks) + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." 
in key] for layer_id in + range(num_output_blocks) } for i in range(1, num_input_blocks): block_id = (i - 1) // (config["layers_per_block"] + 1) layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) - resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key] + resnets = [key for key in input_blocks[i] if + f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key] attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] if f"input_blocks.{i}.0.op.weight" in unet_state_dict: new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( f"input_blocks.{i}.0.op.weight" ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias") + mapping[f'input_blocks.{i}.0.op.weight'] = f"down_blocks.{block_id}.downsamplers.0.conv.weight" + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias") + mapping[f'input_blocks.{i}.0.op.bias'] = f"down_blocks.{block_id}.downsamplers.0.conv.bias" paths = renew_resnet_paths(resnets) meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} @@ -293,7 +304,8 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config): if len(attentions): paths = renew_attention_paths(attentions) meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} - assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], + config=config) resnet_0 = middle_blocks[0] attentions = middle_blocks[1] @@ -307,7 +319,8 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config): attentions_paths = renew_attention_paths(attentions) meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], + config=config) for i in range(num_output_blocks): block_id = i // (config["layers_per_block"] + 1) @@ -330,7 +343,8 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config): paths = renew_resnet_paths(resnets) meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], + config=config) # オリジナル: # if ["conv.weight", "conv.bias"] in output_block_list.values(): @@ -359,7 +373,8 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config): "old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", } - assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], + config=config) else: resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) for path in resnet_0_paths: @@ -373,10 +388,326 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config): if v2 and not config.get('use_linear_projection', False): 
linear_transformer_to_conv(new_checkpoint) + # print("mapping: ", json.dumps(mapping, indent=4)) return new_checkpoint +# ldm key: diffusers key +vae_ldm_to_diffusers_dict = { + "decoder.conv_in.bias": "decoder.conv_in.bias", + "decoder.conv_in.weight": "decoder.conv_in.weight", + "decoder.conv_out.bias": "decoder.conv_out.bias", + "decoder.conv_out.weight": "decoder.conv_out.weight", + "decoder.mid.attn_1.k.bias": "decoder.mid_block.attentions.0.to_k.bias", + "decoder.mid.attn_1.k.weight": "decoder.mid_block.attentions.0.to_k.weight", + "decoder.mid.attn_1.norm.bias": "decoder.mid_block.attentions.0.group_norm.bias", + "decoder.mid.attn_1.norm.weight": "decoder.mid_block.attentions.0.group_norm.weight", + "decoder.mid.attn_1.proj_out.bias": "decoder.mid_block.attentions.0.to_out.0.bias", + "decoder.mid.attn_1.proj_out.weight": "decoder.mid_block.attentions.0.to_out.0.weight", + "decoder.mid.attn_1.q.bias": "decoder.mid_block.attentions.0.to_q.bias", + "decoder.mid.attn_1.q.weight": "decoder.mid_block.attentions.0.to_q.weight", + "decoder.mid.attn_1.v.bias": "decoder.mid_block.attentions.0.to_v.bias", + "decoder.mid.attn_1.v.weight": "decoder.mid_block.attentions.0.to_v.weight", + "decoder.mid.block_1.conv1.bias": "decoder.mid_block.resnets.0.conv1.bias", + "decoder.mid.block_1.conv1.weight": "decoder.mid_block.resnets.0.conv1.weight", + "decoder.mid.block_1.conv2.bias": "decoder.mid_block.resnets.0.conv2.bias", + "decoder.mid.block_1.conv2.weight": "decoder.mid_block.resnets.0.conv2.weight", + "decoder.mid.block_1.norm1.bias": "decoder.mid_block.resnets.0.norm1.bias", + "decoder.mid.block_1.norm1.weight": "decoder.mid_block.resnets.0.norm1.weight", + "decoder.mid.block_1.norm2.bias": "decoder.mid_block.resnets.0.norm2.bias", + "decoder.mid.block_1.norm2.weight": "decoder.mid_block.resnets.0.norm2.weight", + "decoder.mid.block_2.conv1.bias": "decoder.mid_block.resnets.1.conv1.bias", + "decoder.mid.block_2.conv1.weight": "decoder.mid_block.resnets.1.conv1.weight", + "decoder.mid.block_2.conv2.bias": "decoder.mid_block.resnets.1.conv2.bias", + "decoder.mid.block_2.conv2.weight": "decoder.mid_block.resnets.1.conv2.weight", + "decoder.mid.block_2.norm1.bias": "decoder.mid_block.resnets.1.norm1.bias", + "decoder.mid.block_2.norm1.weight": "decoder.mid_block.resnets.1.norm1.weight", + "decoder.mid.block_2.norm2.bias": "decoder.mid_block.resnets.1.norm2.bias", + "decoder.mid.block_2.norm2.weight": "decoder.mid_block.resnets.1.norm2.weight", + "decoder.norm_out.bias": "decoder.conv_norm_out.bias", + "decoder.norm_out.weight": "decoder.conv_norm_out.weight", + "decoder.up.0.block.0.conv1.bias": "decoder.up_blocks.3.resnets.0.conv1.bias", + "decoder.up.0.block.0.conv1.weight": "decoder.up_blocks.3.resnets.0.conv1.weight", + "decoder.up.0.block.0.conv2.bias": "decoder.up_blocks.3.resnets.0.conv2.bias", + "decoder.up.0.block.0.conv2.weight": "decoder.up_blocks.3.resnets.0.conv2.weight", + "decoder.up.0.block.0.nin_shortcut.bias": "decoder.up_blocks.3.resnets.0.conv_shortcut.bias", + "decoder.up.0.block.0.nin_shortcut.weight": "decoder.up_blocks.3.resnets.0.conv_shortcut.weight", + "decoder.up.0.block.0.norm1.bias": "decoder.up_blocks.3.resnets.0.norm1.bias", + "decoder.up.0.block.0.norm1.weight": "decoder.up_blocks.3.resnets.0.norm1.weight", + "decoder.up.0.block.0.norm2.bias": "decoder.up_blocks.3.resnets.0.norm2.bias", + "decoder.up.0.block.0.norm2.weight": "decoder.up_blocks.3.resnets.0.norm2.weight", + "decoder.up.0.block.1.conv1.bias": "decoder.up_blocks.3.resnets.1.conv1.bias", + 
"decoder.up.0.block.1.conv1.weight": "decoder.up_blocks.3.resnets.1.conv1.weight", + "decoder.up.0.block.1.conv2.bias": "decoder.up_blocks.3.resnets.1.conv2.bias", + "decoder.up.0.block.1.conv2.weight": "decoder.up_blocks.3.resnets.1.conv2.weight", + "decoder.up.0.block.1.norm1.bias": "decoder.up_blocks.3.resnets.1.norm1.bias", + "decoder.up.0.block.1.norm1.weight": "decoder.up_blocks.3.resnets.1.norm1.weight", + "decoder.up.0.block.1.norm2.bias": "decoder.up_blocks.3.resnets.1.norm2.bias", + "decoder.up.0.block.1.norm2.weight": "decoder.up_blocks.3.resnets.1.norm2.weight", + "decoder.up.0.block.2.conv1.bias": "decoder.up_blocks.3.resnets.2.conv1.bias", + "decoder.up.0.block.2.conv1.weight": "decoder.up_blocks.3.resnets.2.conv1.weight", + "decoder.up.0.block.2.conv2.bias": "decoder.up_blocks.3.resnets.2.conv2.bias", + "decoder.up.0.block.2.conv2.weight": "decoder.up_blocks.3.resnets.2.conv2.weight", + "decoder.up.0.block.2.norm1.bias": "decoder.up_blocks.3.resnets.2.norm1.bias", + "decoder.up.0.block.2.norm1.weight": "decoder.up_blocks.3.resnets.2.norm1.weight", + "decoder.up.0.block.2.norm2.bias": "decoder.up_blocks.3.resnets.2.norm2.bias", + "decoder.up.0.block.2.norm2.weight": "decoder.up_blocks.3.resnets.2.norm2.weight", + "decoder.up.1.block.0.conv1.bias": "decoder.up_blocks.2.resnets.0.conv1.bias", + "decoder.up.1.block.0.conv1.weight": "decoder.up_blocks.2.resnets.0.conv1.weight", + "decoder.up.1.block.0.conv2.bias": "decoder.up_blocks.2.resnets.0.conv2.bias", + "decoder.up.1.block.0.conv2.weight": "decoder.up_blocks.2.resnets.0.conv2.weight", + "decoder.up.1.block.0.nin_shortcut.bias": "decoder.up_blocks.2.resnets.0.conv_shortcut.bias", + "decoder.up.1.block.0.nin_shortcut.weight": "decoder.up_blocks.2.resnets.0.conv_shortcut.weight", + "decoder.up.1.block.0.norm1.bias": "decoder.up_blocks.2.resnets.0.norm1.bias", + "decoder.up.1.block.0.norm1.weight": "decoder.up_blocks.2.resnets.0.norm1.weight", + "decoder.up.1.block.0.norm2.bias": "decoder.up_blocks.2.resnets.0.norm2.bias", + "decoder.up.1.block.0.norm2.weight": "decoder.up_blocks.2.resnets.0.norm2.weight", + "decoder.up.1.block.1.conv1.bias": "decoder.up_blocks.2.resnets.1.conv1.bias", + "decoder.up.1.block.1.conv1.weight": "decoder.up_blocks.2.resnets.1.conv1.weight", + "decoder.up.1.block.1.conv2.bias": "decoder.up_blocks.2.resnets.1.conv2.bias", + "decoder.up.1.block.1.conv2.weight": "decoder.up_blocks.2.resnets.1.conv2.weight", + "decoder.up.1.block.1.norm1.bias": "decoder.up_blocks.2.resnets.1.norm1.bias", + "decoder.up.1.block.1.norm1.weight": "decoder.up_blocks.2.resnets.1.norm1.weight", + "decoder.up.1.block.1.norm2.bias": "decoder.up_blocks.2.resnets.1.norm2.bias", + "decoder.up.1.block.1.norm2.weight": "decoder.up_blocks.2.resnets.1.norm2.weight", + "decoder.up.1.block.2.conv1.bias": "decoder.up_blocks.2.resnets.2.conv1.bias", + "decoder.up.1.block.2.conv1.weight": "decoder.up_blocks.2.resnets.2.conv1.weight", + "decoder.up.1.block.2.conv2.bias": "decoder.up_blocks.2.resnets.2.conv2.bias", + "decoder.up.1.block.2.conv2.weight": "decoder.up_blocks.2.resnets.2.conv2.weight", + "decoder.up.1.block.2.norm1.bias": "decoder.up_blocks.2.resnets.2.norm1.bias", + "decoder.up.1.block.2.norm1.weight": "decoder.up_blocks.2.resnets.2.norm1.weight", + "decoder.up.1.block.2.norm2.bias": "decoder.up_blocks.2.resnets.2.norm2.bias", + "decoder.up.1.block.2.norm2.weight": "decoder.up_blocks.2.resnets.2.norm2.weight", + "decoder.up.1.upsample.conv.bias": "decoder.up_blocks.2.upsamplers.0.conv.bias", + "decoder.up.1.upsample.conv.weight": 
"decoder.up_blocks.2.upsamplers.0.conv.weight", + "decoder.up.2.block.0.conv1.bias": "decoder.up_blocks.1.resnets.0.conv1.bias", + "decoder.up.2.block.0.conv1.weight": "decoder.up_blocks.1.resnets.0.conv1.weight", + "decoder.up.2.block.0.conv2.bias": "decoder.up_blocks.1.resnets.0.conv2.bias", + "decoder.up.2.block.0.conv2.weight": "decoder.up_blocks.1.resnets.0.conv2.weight", + "decoder.up.2.block.0.norm1.bias": "decoder.up_blocks.1.resnets.0.norm1.bias", + "decoder.up.2.block.0.norm1.weight": "decoder.up_blocks.1.resnets.0.norm1.weight", + "decoder.up.2.block.0.norm2.bias": "decoder.up_blocks.1.resnets.0.norm2.bias", + "decoder.up.2.block.0.norm2.weight": "decoder.up_blocks.1.resnets.0.norm2.weight", + "decoder.up.2.block.1.conv1.bias": "decoder.up_blocks.1.resnets.1.conv1.bias", + "decoder.up.2.block.1.conv1.weight": "decoder.up_blocks.1.resnets.1.conv1.weight", + "decoder.up.2.block.1.conv2.bias": "decoder.up_blocks.1.resnets.1.conv2.bias", + "decoder.up.2.block.1.conv2.weight": "decoder.up_blocks.1.resnets.1.conv2.weight", + "decoder.up.2.block.1.norm1.bias": "decoder.up_blocks.1.resnets.1.norm1.bias", + "decoder.up.2.block.1.norm1.weight": "decoder.up_blocks.1.resnets.1.norm1.weight", + "decoder.up.2.block.1.norm2.bias": "decoder.up_blocks.1.resnets.1.norm2.bias", + "decoder.up.2.block.1.norm2.weight": "decoder.up_blocks.1.resnets.1.norm2.weight", + "decoder.up.2.block.2.conv1.bias": "decoder.up_blocks.1.resnets.2.conv1.bias", + "decoder.up.2.block.2.conv1.weight": "decoder.up_blocks.1.resnets.2.conv1.weight", + "decoder.up.2.block.2.conv2.bias": "decoder.up_blocks.1.resnets.2.conv2.bias", + "decoder.up.2.block.2.conv2.weight": "decoder.up_blocks.1.resnets.2.conv2.weight", + "decoder.up.2.block.2.norm1.bias": "decoder.up_blocks.1.resnets.2.norm1.bias", + "decoder.up.2.block.2.norm1.weight": "decoder.up_blocks.1.resnets.2.norm1.weight", + "decoder.up.2.block.2.norm2.bias": "decoder.up_blocks.1.resnets.2.norm2.bias", + "decoder.up.2.block.2.norm2.weight": "decoder.up_blocks.1.resnets.2.norm2.weight", + "decoder.up.2.upsample.conv.bias": "decoder.up_blocks.1.upsamplers.0.conv.bias", + "decoder.up.2.upsample.conv.weight": "decoder.up_blocks.1.upsamplers.0.conv.weight", + "decoder.up.3.block.0.conv1.bias": "decoder.up_blocks.0.resnets.0.conv1.bias", + "decoder.up.3.block.0.conv1.weight": "decoder.up_blocks.0.resnets.0.conv1.weight", + "decoder.up.3.block.0.conv2.bias": "decoder.up_blocks.0.resnets.0.conv2.bias", + "decoder.up.3.block.0.conv2.weight": "decoder.up_blocks.0.resnets.0.conv2.weight", + "decoder.up.3.block.0.norm1.bias": "decoder.up_blocks.0.resnets.0.norm1.bias", + "decoder.up.3.block.0.norm1.weight": "decoder.up_blocks.0.resnets.0.norm1.weight", + "decoder.up.3.block.0.norm2.bias": "decoder.up_blocks.0.resnets.0.norm2.bias", + "decoder.up.3.block.0.norm2.weight": "decoder.up_blocks.0.resnets.0.norm2.weight", + "decoder.up.3.block.1.conv1.bias": "decoder.up_blocks.0.resnets.1.conv1.bias", + "decoder.up.3.block.1.conv1.weight": "decoder.up_blocks.0.resnets.1.conv1.weight", + "decoder.up.3.block.1.conv2.bias": "decoder.up_blocks.0.resnets.1.conv2.bias", + "decoder.up.3.block.1.conv2.weight": "decoder.up_blocks.0.resnets.1.conv2.weight", + "decoder.up.3.block.1.norm1.bias": "decoder.up_blocks.0.resnets.1.norm1.bias", + "decoder.up.3.block.1.norm1.weight": "decoder.up_blocks.0.resnets.1.norm1.weight", + "decoder.up.3.block.1.norm2.bias": "decoder.up_blocks.0.resnets.1.norm2.bias", + "decoder.up.3.block.1.norm2.weight": "decoder.up_blocks.0.resnets.1.norm2.weight", + 
"decoder.up.3.block.2.conv1.bias": "decoder.up_blocks.0.resnets.2.conv1.bias", + "decoder.up.3.block.2.conv1.weight": "decoder.up_blocks.0.resnets.2.conv1.weight", + "decoder.up.3.block.2.conv2.bias": "decoder.up_blocks.0.resnets.2.conv2.bias", + "decoder.up.3.block.2.conv2.weight": "decoder.up_blocks.0.resnets.2.conv2.weight", + "decoder.up.3.block.2.norm1.bias": "decoder.up_blocks.0.resnets.2.norm1.bias", + "decoder.up.3.block.2.norm1.weight": "decoder.up_blocks.0.resnets.2.norm1.weight", + "decoder.up.3.block.2.norm2.bias": "decoder.up_blocks.0.resnets.2.norm2.bias", + "decoder.up.3.block.2.norm2.weight": "decoder.up_blocks.0.resnets.2.norm2.weight", + "decoder.up.3.upsample.conv.bias": "decoder.up_blocks.0.upsamplers.0.conv.bias", + "decoder.up.3.upsample.conv.weight": "decoder.up_blocks.0.upsamplers.0.conv.weight", + "encoder.conv_in.bias": "encoder.conv_in.bias", + "encoder.conv_in.weight": "encoder.conv_in.weight", + "encoder.conv_out.bias": "encoder.conv_out.bias", + "encoder.conv_out.weight": "encoder.conv_out.weight", + "encoder.down.0.block.0.conv1.bias": "encoder.down_blocks.0.resnets.0.conv1.bias", + "encoder.down.0.block.0.conv1.weight": "encoder.down_blocks.0.resnets.0.conv1.weight", + "encoder.down.0.block.0.conv2.bias": "encoder.down_blocks.0.resnets.0.conv2.bias", + "encoder.down.0.block.0.conv2.weight": "encoder.down_blocks.0.resnets.0.conv2.weight", + "encoder.down.0.block.0.norm1.bias": "encoder.down_blocks.0.resnets.0.norm1.bias", + "encoder.down.0.block.0.norm1.weight": "encoder.down_blocks.0.resnets.0.norm1.weight", + "encoder.down.0.block.0.norm2.bias": "encoder.down_blocks.0.resnets.0.norm2.bias", + "encoder.down.0.block.0.norm2.weight": "encoder.down_blocks.0.resnets.0.norm2.weight", + "encoder.down.0.block.1.conv1.bias": "encoder.down_blocks.0.resnets.1.conv1.bias", + "encoder.down.0.block.1.conv1.weight": "encoder.down_blocks.0.resnets.1.conv1.weight", + "encoder.down.0.block.1.conv2.bias": "encoder.down_blocks.0.resnets.1.conv2.bias", + "encoder.down.0.block.1.conv2.weight": "encoder.down_blocks.0.resnets.1.conv2.weight", + "encoder.down.0.block.1.norm1.bias": "encoder.down_blocks.0.resnets.1.norm1.bias", + "encoder.down.0.block.1.norm1.weight": "encoder.down_blocks.0.resnets.1.norm1.weight", + "encoder.down.0.block.1.norm2.bias": "encoder.down_blocks.0.resnets.1.norm2.bias", + "encoder.down.0.block.1.norm2.weight": "encoder.down_blocks.0.resnets.1.norm2.weight", + "encoder.down.0.downsample.conv.bias": "encoder.down_blocks.0.downsamplers.0.conv.bias", + "encoder.down.0.downsample.conv.weight": "encoder.down_blocks.0.downsamplers.0.conv.weight", + "encoder.down.1.block.0.conv1.bias": "encoder.down_blocks.1.resnets.0.conv1.bias", + "encoder.down.1.block.0.conv1.weight": "encoder.down_blocks.1.resnets.0.conv1.weight", + "encoder.down.1.block.0.conv2.bias": "encoder.down_blocks.1.resnets.0.conv2.bias", + "encoder.down.1.block.0.conv2.weight": "encoder.down_blocks.1.resnets.0.conv2.weight", + "encoder.down.1.block.0.nin_shortcut.bias": "encoder.down_blocks.1.resnets.0.conv_shortcut.bias", + "encoder.down.1.block.0.nin_shortcut.weight": "encoder.down_blocks.1.resnets.0.conv_shortcut.weight", + "encoder.down.1.block.0.norm1.bias": "encoder.down_blocks.1.resnets.0.norm1.bias", + "encoder.down.1.block.0.norm1.weight": "encoder.down_blocks.1.resnets.0.norm1.weight", + "encoder.down.1.block.0.norm2.bias": "encoder.down_blocks.1.resnets.0.norm2.bias", + "encoder.down.1.block.0.norm2.weight": "encoder.down_blocks.1.resnets.0.norm2.weight", + 
"encoder.down.1.block.1.conv1.bias": "encoder.down_blocks.1.resnets.1.conv1.bias", + "encoder.down.1.block.1.conv1.weight": "encoder.down_blocks.1.resnets.1.conv1.weight", + "encoder.down.1.block.1.conv2.bias": "encoder.down_blocks.1.resnets.1.conv2.bias", + "encoder.down.1.block.1.conv2.weight": "encoder.down_blocks.1.resnets.1.conv2.weight", + "encoder.down.1.block.1.norm1.bias": "encoder.down_blocks.1.resnets.1.norm1.bias", + "encoder.down.1.block.1.norm1.weight": "encoder.down_blocks.1.resnets.1.norm1.weight", + "encoder.down.1.block.1.norm2.bias": "encoder.down_blocks.1.resnets.1.norm2.bias", + "encoder.down.1.block.1.norm2.weight": "encoder.down_blocks.1.resnets.1.norm2.weight", + "encoder.down.1.downsample.conv.bias": "encoder.down_blocks.1.downsamplers.0.conv.bias", + "encoder.down.1.downsample.conv.weight": "encoder.down_blocks.1.downsamplers.0.conv.weight", + "encoder.down.2.block.0.conv1.bias": "encoder.down_blocks.2.resnets.0.conv1.bias", + "encoder.down.2.block.0.conv1.weight": "encoder.down_blocks.2.resnets.0.conv1.weight", + "encoder.down.2.block.0.conv2.bias": "encoder.down_blocks.2.resnets.0.conv2.bias", + "encoder.down.2.block.0.conv2.weight": "encoder.down_blocks.2.resnets.0.conv2.weight", + "encoder.down.2.block.0.nin_shortcut.bias": "encoder.down_blocks.2.resnets.0.conv_shortcut.bias", + "encoder.down.2.block.0.nin_shortcut.weight": "encoder.down_blocks.2.resnets.0.conv_shortcut.weight", + "encoder.down.2.block.0.norm1.bias": "encoder.down_blocks.2.resnets.0.norm1.bias", + "encoder.down.2.block.0.norm1.weight": "encoder.down_blocks.2.resnets.0.norm1.weight", + "encoder.down.2.block.0.norm2.bias": "encoder.down_blocks.2.resnets.0.norm2.bias", + "encoder.down.2.block.0.norm2.weight": "encoder.down_blocks.2.resnets.0.norm2.weight", + "encoder.down.2.block.1.conv1.bias": "encoder.down_blocks.2.resnets.1.conv1.bias", + "encoder.down.2.block.1.conv1.weight": "encoder.down_blocks.2.resnets.1.conv1.weight", + "encoder.down.2.block.1.conv2.bias": "encoder.down_blocks.2.resnets.1.conv2.bias", + "encoder.down.2.block.1.conv2.weight": "encoder.down_blocks.2.resnets.1.conv2.weight", + "encoder.down.2.block.1.norm1.bias": "encoder.down_blocks.2.resnets.1.norm1.bias", + "encoder.down.2.block.1.norm1.weight": "encoder.down_blocks.2.resnets.1.norm1.weight", + "encoder.down.2.block.1.norm2.bias": "encoder.down_blocks.2.resnets.1.norm2.bias", + "encoder.down.2.block.1.norm2.weight": "encoder.down_blocks.2.resnets.1.norm2.weight", + "encoder.down.2.downsample.conv.bias": "encoder.down_blocks.2.downsamplers.0.conv.bias", + "encoder.down.2.downsample.conv.weight": "encoder.down_blocks.2.downsamplers.0.conv.weight", + "encoder.down.3.block.0.conv1.bias": "encoder.down_blocks.3.resnets.0.conv1.bias", + "encoder.down.3.block.0.conv1.weight": "encoder.down_blocks.3.resnets.0.conv1.weight", + "encoder.down.3.block.0.conv2.bias": "encoder.down_blocks.3.resnets.0.conv2.bias", + "encoder.down.3.block.0.conv2.weight": "encoder.down_blocks.3.resnets.0.conv2.weight", + "encoder.down.3.block.0.norm1.bias": "encoder.down_blocks.3.resnets.0.norm1.bias", + "encoder.down.3.block.0.norm1.weight": "encoder.down_blocks.3.resnets.0.norm1.weight", + "encoder.down.3.block.0.norm2.bias": "encoder.down_blocks.3.resnets.0.norm2.bias", + "encoder.down.3.block.0.norm2.weight": "encoder.down_blocks.3.resnets.0.norm2.weight", + "encoder.down.3.block.1.conv1.bias": "encoder.down_blocks.3.resnets.1.conv1.bias", + "encoder.down.3.block.1.conv1.weight": "encoder.down_blocks.3.resnets.1.conv1.weight", + 
"encoder.down.3.block.1.conv2.bias": "encoder.down_blocks.3.resnets.1.conv2.bias", + "encoder.down.3.block.1.conv2.weight": "encoder.down_blocks.3.resnets.1.conv2.weight", + "encoder.down.3.block.1.norm1.bias": "encoder.down_blocks.3.resnets.1.norm1.bias", + "encoder.down.3.block.1.norm1.weight": "encoder.down_blocks.3.resnets.1.norm1.weight", + "encoder.down.3.block.1.norm2.bias": "encoder.down_blocks.3.resnets.1.norm2.bias", + "encoder.down.3.block.1.norm2.weight": "encoder.down_blocks.3.resnets.1.norm2.weight", + "encoder.mid.attn_1.k.bias": "encoder.mid_block.attentions.0.to_k.bias", + "encoder.mid.attn_1.k.weight": "encoder.mid_block.attentions.0.to_k.weight", + "encoder.mid.attn_1.norm.bias": "encoder.mid_block.attentions.0.group_norm.bias", + "encoder.mid.attn_1.norm.weight": "encoder.mid_block.attentions.0.group_norm.weight", + "encoder.mid.attn_1.proj_out.bias": "encoder.mid_block.attentions.0.to_out.0.bias", + "encoder.mid.attn_1.proj_out.weight": "encoder.mid_block.attentions.0.to_out.0.weight", + "encoder.mid.attn_1.q.bias": "encoder.mid_block.attentions.0.to_q.bias", + "encoder.mid.attn_1.q.weight": "encoder.mid_block.attentions.0.to_q.weight", + "encoder.mid.attn_1.v.bias": "encoder.mid_block.attentions.0.to_v.bias", + "encoder.mid.attn_1.v.weight": "encoder.mid_block.attentions.0.to_v.weight", + "encoder.mid.block_1.conv1.bias": "encoder.mid_block.resnets.0.conv1.bias", + "encoder.mid.block_1.conv1.weight": "encoder.mid_block.resnets.0.conv1.weight", + "encoder.mid.block_1.conv2.bias": "encoder.mid_block.resnets.0.conv2.bias", + "encoder.mid.block_1.conv2.weight": "encoder.mid_block.resnets.0.conv2.weight", + "encoder.mid.block_1.norm1.bias": "encoder.mid_block.resnets.0.norm1.bias", + "encoder.mid.block_1.norm1.weight": "encoder.mid_block.resnets.0.norm1.weight", + "encoder.mid.block_1.norm2.bias": "encoder.mid_block.resnets.0.norm2.bias", + "encoder.mid.block_1.norm2.weight": "encoder.mid_block.resnets.0.norm2.weight", + "encoder.mid.block_2.conv1.bias": "encoder.mid_block.resnets.1.conv1.bias", + "encoder.mid.block_2.conv1.weight": "encoder.mid_block.resnets.1.conv1.weight", + "encoder.mid.block_2.conv2.bias": "encoder.mid_block.resnets.1.conv2.bias", + "encoder.mid.block_2.conv2.weight": "encoder.mid_block.resnets.1.conv2.weight", + "encoder.mid.block_2.norm1.bias": "encoder.mid_block.resnets.1.norm1.bias", + "encoder.mid.block_2.norm1.weight": "encoder.mid_block.resnets.1.norm1.weight", + "encoder.mid.block_2.norm2.bias": "encoder.mid_block.resnets.1.norm2.bias", + "encoder.mid.block_2.norm2.weight": "encoder.mid_block.resnets.1.norm2.weight", + "encoder.norm_out.bias": "encoder.conv_norm_out.bias", + "encoder.norm_out.weight": "encoder.conv_norm_out.weight", + "post_quant_conv.bias": "post_quant_conv.bias", + "post_quant_conv.weight": "post_quant_conv.weight", + "quant_conv.bias": "quant_conv.bias", + "quant_conv.weight": "quant_conv.weight" +} + + +def get_diffusers_vae_key_from_ldm_key(target_ldm_key, i=None): + for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items(): + if i is not None: + ldm_key = ldm_key.replace("{i}", str(i)) + diffusers_key = diffusers_key.replace("{i}", str(i)) + if ldm_key == target_ldm_key: + return diffusers_key + + if ldm_key in vae_ldm_to_diffusers_dict: + return vae_ldm_to_diffusers_dict[ldm_key] + else: + return None + +# def get_ldm_vae_key_from_diffusers_key(target_diffusers_key): +# for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items(): +# if diffusers_key == target_diffusers_key: +# return ldm_key +# return None + 
+def get_ldm_vae_key_from_diffusers_key(target_diffusers_key): + for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items(): + if "{" in diffusers_key: # if we have a placeholder + # escape special characters in the key, and replace the placeholder with a regex group + pattern = re.escape(diffusers_key).replace("\\{i\\}", "(\\d+)") + match = re.match(pattern, target_diffusers_key) + if match: # if we found a match + return ldm_key.format(i=match.group(1)) + elif diffusers_key == target_diffusers_key: + return ldm_key + return None + + +vae_keys_squished_on_diffusers = [ + "decoder.mid_block.attentions.0.to_k.weight", + "decoder.mid_block.attentions.0.to_out.0.weight", + "decoder.mid_block.attentions.0.to_q.weight", + "decoder.mid_block.attentions.0.to_v.weight", + "encoder.mid_block.attentions.0.to_k.weight", + "encoder.mid_block.attentions.0.to_out.0.weight", + "encoder.mid_block.attentions.0.to_q.weight", + "encoder.mid_block.attentions.0.to_v.weight" +] + +def convert_diffusers_back_to_ldm(diffusers_vae): + new_state_dict = OrderedDict() + diffusers_state_dict = diffusers_vae.state_dict() + for key, value in diffusers_state_dict.items(): + val_to_save = value + if key in vae_keys_squished_on_diffusers: + val_to_save = value.clone() + # (512, 512) diffusers and (512, 512, 1, 1) ldm + val_to_save = val_to_save.unsqueeze(-1).unsqueeze(-1) + ldm_key = get_ldm_vae_key_from_diffusers_key(key) + if ldm_key is not None: + new_state_dict[ldm_key] = val_to_save + else: + # for now add current key + new_state_dict[key] = val_to_save + return new_state_dict + + def convert_ldm_vae_checkpoint(checkpoint, config): + mapping = {} # extract state dict for VAE vae_state_dict = {} vae_key = "first_stage_model." @@ -390,6 +721,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config): new_checkpoint = {} + # for key in list(vae_state_dict.keys()): + # diffusers_key = get_diffusers_vae_key_from_ldm_key(key) + # if diffusers_key is not None: + # new_checkpoint[diffusers_key] = vae_state_dict[key] + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] @@ -411,11 +747,13 @@ def convert_ldm_vae_checkpoint(checkpoint, config): # Retrieves the keys for the encoder down blocks only num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) - down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)} + down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in + range(num_down_blocks)} # Retrieves the keys for the decoder up blocks only num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) - up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)} + up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in + range(num_up_blocks)} for i in range(num_down_blocks): resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] @@ -424,9 +762,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config): new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( f"encoder.down.{i}.downsample.conv.weight" ) + 
mapping[f"encoder.down.{i}.downsample.conv.weight"] = f"encoder.down_blocks.{i}.downsamplers.0.conv.weight" new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( f"encoder.down.{i}.downsample.conv.bias" ) + mapping[f"encoder.down.{i}.downsample.conv.bias"] = f"encoder.down_blocks.{i}.downsamplers.0.conv.bias" paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} @@ -449,15 +789,18 @@ def convert_ldm_vae_checkpoint(checkpoint, config): for i in range(num_up_blocks): block_id = num_up_blocks - 1 - i - resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key] + resnets = [key for key in up_blocks[block_id] if + f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key] if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ f"decoder.up.{block_id}.upsample.conv.weight" ] + mapping[f"decoder.up.{block_id}.upsample.conv.weight"] = f"decoder.up_blocks.{i}.upsamplers.0.conv.weight" new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ f"decoder.up.{block_id}.upsample.conv.bias" ] + mapping[f"decoder.up.{block_id}.upsample.conv.bias"] = f"decoder.up_blocks.{i}.upsamplers.0.conv.bias" paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} @@ -548,7 +891,7 @@ def convert_ldm_clip_checkpoint_v1(checkpoint): text_model_dict = {} for key in keys: if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] return text_model_dict @@ -677,36 +1020,36 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict): for j in range(2): # loop over resnets/attentions for downblocks hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." - sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0." unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) if i < 3: # no attention layers in down_blocks.3 hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." - sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1." unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) for j in range(3): # loop over resnets/attentions for upblocks hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." - sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + sd_up_res_prefix = f"output_blocks.{3 * i + j}.0." unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) if i > 0: # no attention layers in up_blocks.0 hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." - sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1." unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) if i < 3: # no downsample in down_blocks.3 hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." - sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op." unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) # no upsample in up_blocks.3 hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." 
+ sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}." unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) hf_mid_atn_prefix = "mid_block.attentions.0." @@ -715,7 +1058,7 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict): for j in range(2): hf_mid_res_prefix = f"mid_block.resnets.{j}." - sd_mid_res_prefix = f"middle_block.{2*j}." + sd_mid_res_prefix = f"middle_block.{2 * j}." unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) # buyer beware: this is a *brittle* function, @@ -772,20 +1115,20 @@ def convert_vae_state_dict(vae_state_dict): vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"up.{3-i}.upsample." + sd_upsample_prefix = f"up.{3 - i}.upsample." vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) # up_blocks have three resnets # also, up blocks in hf are numbered in reverse from sd for j in range(3): hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." - sd_up_prefix = f"decoder.up.{3-i}.block.{j}." + sd_up_prefix = f"decoder.up.{3 - i}.block.{j}." vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) # this part accounts for mid blocks in both the encoder and the decoder for i in range(2): hf_mid_res_prefix = f"mid_block.resnets.{i}." - sd_mid_res_prefix = f"mid.block_{i+1}." + sd_mid_res_prefix = f"mid.block_{i + 1}." vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) vae_conversion_map_attn = [ @@ -850,7 +1193,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: for key in state_dict.keys(): if key.startswith(rep_from): - new_key = rep_to + key[len(rep_from) :] + new_key = rep_to + key[len(rep_from):] key_reps.append((key, new_key)) for key, new_key in key_reps: @@ -861,7 +1204,8 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 -def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=False): +def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, + unet_use_linear_projection_in_v2=False): _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device) # Convert the UNet2DConditionModel model. @@ -990,7 +1334,8 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals return new_sd -def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None): +def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, + vae=None): if ckpt_path is not None: # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) @@ -1059,7 +1404,8 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p return key_count -def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False): +def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, + use_safetensors=False): if pretrained_model_name_or_path is None: # load default settings for v1/v2 if v2: @@ -1177,4 +1523,4 @@ if __name__ == "__main__": for ar in aspect_ratios: if ar in ars: print("error! 
duplicate ar:", ar) - ars.add(ar) \ No newline at end of file + ars.add(ar) diff --git a/toolkit/losses.py b/toolkit/losses.py new file mode 100644 index 00000000..9c0ae097 --- /dev/null +++ b/toolkit/losses.py @@ -0,0 +1,23 @@ +import torch + + +def total_variation(image): + """ + Compute normalized total variation. + Inputs: + - image: PyTorch Variable of shape (N, C, H, W) + Returns: + - TV: total variation normalized by the number of elements + """ + n_elements = image.shape[1] * image.shape[2] * image.shape[3] + return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + + torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements) + + +class ComparativeTotalVariation(torch.nn.Module): + """ + Compute the comparative loss in tv between two images. to match their tv + """ + + def forward(self, pred, target): + return torch.abs(total_variation(pred) - total_variation(target)) diff --git a/toolkit/style.py b/toolkit/style.py index 3c810499..52be08ba 100644 --- a/toolkit/style.py +++ b/toolkit/style.py @@ -127,11 +127,12 @@ class Normalization(nn.Module): def get_style_model_and_losses( single_target=False, - device='cuda' if torch.cuda.is_available() else 'cpu' + device='cuda' if torch.cuda.is_available() else 'cpu', + output_layer_name=None, ): # content_layers = ['conv_4'] # style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5'] - content_layers = ['conv3_2', 'conv4_2'] + content_layers = ['conv4_2'] style_layers = ['conv2_1', 'conv3_1', 'conv4_1'] cnn = models.vgg19(pretrained=True).features.to(device).eval() # normalization module @@ -150,6 +151,8 @@ def get_style_model_and_losses( block = 1 children = list(cnn.children()) + output_layer = None + for layer in children: if isinstance(layer, nn.Conv2d): i += 1 @@ -184,11 +187,16 @@ def get_style_model_and_losses( model.add_module("style_loss_{}_{}".format(block, i), style_loss) style_losses.append(style_loss) + if output_layer_name is not None and name == output_layer_name: + output_layer = layer + # now we trim off the layers after the last content and style losses for i in range(len(model) - 1, -1, -1): if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss): break + if output_layer_name is not None and model[i].name == output_layer_name: + break model = model[:(i + 1)] - return model, style_losses, content_losses + return model, style_losses, content_losses, output_layer
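
A minimal round-trip sanity check for the new LDM <-> diffusers VAE conversion path, condensed from what testing/test_vae_cycle.py does; it assumes `toolkit` is importable and that `vae_path` is a placeholder pointing at an existing LDM-format VAE `.safetensors` file:

    # Load an LDM VAE, convert it to diffusers and back, and compare key sets/values.
    import torch
    from safetensors.torch import load_file
    from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm

    vae_path = '/path/to/vae-ft-mse-840000-ema-pruned.safetensors'  # placeholder path

    ldm_state_dict = load_file(vae_path)                       # original LDM-format keys
    diffusers_vae = load_vae(vae_path, dtype=torch.float32)    # diffusers AutoencoderKL
    round_trip = convert_diffusers_back_to_ldm(diffusers_vae)  # back to LDM-format keys

    missing = [k for k in ldm_state_dict if k not in round_trip]
    extra = [k for k in round_trip if k not in ldm_state_dict]
    print(f"missing from round trip: {len(missing)}, extra keys: {len(extra)}")

    # Tensors should survive the round trip (the "squished" attention weights are
    # unsqueezed back to their (C, C, 1, 1) LDM shapes by convert_diffusers_back_to_ldm).
    for key in ldm_state_dict:
        if key in round_trip:
            assert torch.allclose(ldm_state_dict[key].float(),
                                  round_trip[key].float(), atol=1e-6), key

If the conversion table and `convert_diffusers_back_to_ldm` are consistent, both key-difference lists should be empty and no assertion should fire.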