mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Handle conversions back to ldm for saving
This commit is contained in:
@@ -1,16 +1,19 @@
|
||||
# mostly from https://github.com/kohya-ss/sd-scripts/blob/main/library/model_util.py
|
||||
# I am infinitely grateful to @kohya-ss for their amazing work in this field.
|
||||
# This version is updated to handle the latest version of the diffusers library.
|
||||
|
||||
import json
|
||||
# v1: split from train_db_fixed.py.
|
||||
# v2: support safetensors
|
||||
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from safetensors.torch import load_file, save_file
|
||||
from collections import OrderedDict
|
||||
|
||||
# DiffUsers版StableDiffusionのモデルパラメータ
|
||||
NUM_TRAIN_TIMESTEPS = 1000
|
||||
@@ -151,7 +154,7 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
|
||||
|
||||
def assign_to_checkpoint(
|
||||
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
||||
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
||||
):
|
||||
"""
|
||||
This does the final conversion step: take locally converted weights and apply a global renaming
|
||||
@@ -228,6 +231,7 @@ def linear_transformer_to_conv(checkpoint):
|
||||
|
||||
|
||||
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
mapping = {}
|
||||
"""
|
||||
Takes a state dict and a config, and returns a converted checkpoint.
|
||||
"""
|
||||
@@ -258,33 +262,40 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
# Retrieves the keys for the input blocks only
|
||||
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
||||
input_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
|
||||
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in
|
||||
range(num_input_blocks)
|
||||
}
|
||||
|
||||
# Retrieves the keys for the middle blocks only
|
||||
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
||||
middle_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
|
||||
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in
|
||||
range(num_middle_blocks)
|
||||
}
|
||||
|
||||
# Retrieves the keys for the output blocks only
|
||||
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
||||
output_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
|
||||
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in
|
||||
range(num_output_blocks)
|
||||
}
|
||||
|
||||
for i in range(1, num_input_blocks):
|
||||
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
||||
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
||||
|
||||
resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
|
||||
resnets = [key for key in input_blocks[i] if
|
||||
f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
|
||||
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
||||
|
||||
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
||||
f"input_blocks.{i}.0.op.weight"
|
||||
)
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
|
||||
mapping[f'input_blocks.{i}.0.op.weight'] = f"down_blocks.{block_id}.downsamplers.0.conv.weight"
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
||||
f"input_blocks.{i}.0.op.bias")
|
||||
mapping[f'input_blocks.{i}.0.op.bias'] = f"down_blocks.{block_id}.downsamplers.0.conv.bias"
|
||||
|
||||
paths = renew_resnet_paths(resnets)
|
||||
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||
@@ -293,7 +304,8 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
if len(attentions):
|
||||
paths = renew_attention_paths(attentions)
|
||||
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
||||
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path],
|
||||
config=config)
|
||||
|
||||
resnet_0 = middle_blocks[0]
|
||||
attentions = middle_blocks[1]
|
||||
@@ -307,7 +319,8 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
|
||||
attentions_paths = renew_attention_paths(attentions)
|
||||
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
||||
assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
||||
assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path],
|
||||
config=config)
|
||||
|
||||
for i in range(num_output_blocks):
|
||||
block_id = i // (config["layers_per_block"] + 1)
|
||||
@@ -330,7 +343,8 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
paths = renew_resnet_paths(resnets)
|
||||
|
||||
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
||||
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path],
|
||||
config=config)
|
||||
|
||||
# オリジナル:
|
||||
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
||||
@@ -359,7 +373,8 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
"old": f"output_blocks.{i}.1",
|
||||
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
||||
}
|
||||
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
||||
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path],
|
||||
config=config)
|
||||
else:
|
||||
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
||||
for path in resnet_0_paths:
|
||||
@@ -373,10 +388,326 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
if v2 and not config.get('use_linear_projection', False):
|
||||
linear_transformer_to_conv(new_checkpoint)
|
||||
|
||||
# print("mapping: ", json.dumps(mapping, indent=4))
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
# ldm key: diffusers key
|
||||
vae_ldm_to_diffusers_dict = {
|
||||
"decoder.conv_in.bias": "decoder.conv_in.bias",
|
||||
"decoder.conv_in.weight": "decoder.conv_in.weight",
|
||||
"decoder.conv_out.bias": "decoder.conv_out.bias",
|
||||
"decoder.conv_out.weight": "decoder.conv_out.weight",
|
||||
"decoder.mid.attn_1.k.bias": "decoder.mid_block.attentions.0.to_k.bias",
|
||||
"decoder.mid.attn_1.k.weight": "decoder.mid_block.attentions.0.to_k.weight",
|
||||
"decoder.mid.attn_1.norm.bias": "decoder.mid_block.attentions.0.group_norm.bias",
|
||||
"decoder.mid.attn_1.norm.weight": "decoder.mid_block.attentions.0.group_norm.weight",
|
||||
"decoder.mid.attn_1.proj_out.bias": "decoder.mid_block.attentions.0.to_out.0.bias",
|
||||
"decoder.mid.attn_1.proj_out.weight": "decoder.mid_block.attentions.0.to_out.0.weight",
|
||||
"decoder.mid.attn_1.q.bias": "decoder.mid_block.attentions.0.to_q.bias",
|
||||
"decoder.mid.attn_1.q.weight": "decoder.mid_block.attentions.0.to_q.weight",
|
||||
"decoder.mid.attn_1.v.bias": "decoder.mid_block.attentions.0.to_v.bias",
|
||||
"decoder.mid.attn_1.v.weight": "decoder.mid_block.attentions.0.to_v.weight",
|
||||
"decoder.mid.block_1.conv1.bias": "decoder.mid_block.resnets.0.conv1.bias",
|
||||
"decoder.mid.block_1.conv1.weight": "decoder.mid_block.resnets.0.conv1.weight",
|
||||
"decoder.mid.block_1.conv2.bias": "decoder.mid_block.resnets.0.conv2.bias",
|
||||
"decoder.mid.block_1.conv2.weight": "decoder.mid_block.resnets.0.conv2.weight",
|
||||
"decoder.mid.block_1.norm1.bias": "decoder.mid_block.resnets.0.norm1.bias",
|
||||
"decoder.mid.block_1.norm1.weight": "decoder.mid_block.resnets.0.norm1.weight",
|
||||
"decoder.mid.block_1.norm2.bias": "decoder.mid_block.resnets.0.norm2.bias",
|
||||
"decoder.mid.block_1.norm2.weight": "decoder.mid_block.resnets.0.norm2.weight",
|
||||
"decoder.mid.block_2.conv1.bias": "decoder.mid_block.resnets.1.conv1.bias",
|
||||
"decoder.mid.block_2.conv1.weight": "decoder.mid_block.resnets.1.conv1.weight",
|
||||
"decoder.mid.block_2.conv2.bias": "decoder.mid_block.resnets.1.conv2.bias",
|
||||
"decoder.mid.block_2.conv2.weight": "decoder.mid_block.resnets.1.conv2.weight",
|
||||
"decoder.mid.block_2.norm1.bias": "decoder.mid_block.resnets.1.norm1.bias",
|
||||
"decoder.mid.block_2.norm1.weight": "decoder.mid_block.resnets.1.norm1.weight",
|
||||
"decoder.mid.block_2.norm2.bias": "decoder.mid_block.resnets.1.norm2.bias",
|
||||
"decoder.mid.block_2.norm2.weight": "decoder.mid_block.resnets.1.norm2.weight",
|
||||
"decoder.norm_out.bias": "decoder.conv_norm_out.bias",
|
||||
"decoder.norm_out.weight": "decoder.conv_norm_out.weight",
|
||||
"decoder.up.0.block.0.conv1.bias": "decoder.up_blocks.3.resnets.0.conv1.bias",
|
||||
"decoder.up.0.block.0.conv1.weight": "decoder.up_blocks.3.resnets.0.conv1.weight",
|
||||
"decoder.up.0.block.0.conv2.bias": "decoder.up_blocks.3.resnets.0.conv2.bias",
|
||||
"decoder.up.0.block.0.conv2.weight": "decoder.up_blocks.3.resnets.0.conv2.weight",
|
||||
"decoder.up.0.block.0.nin_shortcut.bias": "decoder.up_blocks.3.resnets.0.conv_shortcut.bias",
|
||||
"decoder.up.0.block.0.nin_shortcut.weight": "decoder.up_blocks.3.resnets.0.conv_shortcut.weight",
|
||||
"decoder.up.0.block.0.norm1.bias": "decoder.up_blocks.3.resnets.0.norm1.bias",
|
||||
"decoder.up.0.block.0.norm1.weight": "decoder.up_blocks.3.resnets.0.norm1.weight",
|
||||
"decoder.up.0.block.0.norm2.bias": "decoder.up_blocks.3.resnets.0.norm2.bias",
|
||||
"decoder.up.0.block.0.norm2.weight": "decoder.up_blocks.3.resnets.0.norm2.weight",
|
||||
"decoder.up.0.block.1.conv1.bias": "decoder.up_blocks.3.resnets.1.conv1.bias",
|
||||
"decoder.up.0.block.1.conv1.weight": "decoder.up_blocks.3.resnets.1.conv1.weight",
|
||||
"decoder.up.0.block.1.conv2.bias": "decoder.up_blocks.3.resnets.1.conv2.bias",
|
||||
"decoder.up.0.block.1.conv2.weight": "decoder.up_blocks.3.resnets.1.conv2.weight",
|
||||
"decoder.up.0.block.1.norm1.bias": "decoder.up_blocks.3.resnets.1.norm1.bias",
|
||||
"decoder.up.0.block.1.norm1.weight": "decoder.up_blocks.3.resnets.1.norm1.weight",
|
||||
"decoder.up.0.block.1.norm2.bias": "decoder.up_blocks.3.resnets.1.norm2.bias",
|
||||
"decoder.up.0.block.1.norm2.weight": "decoder.up_blocks.3.resnets.1.norm2.weight",
|
||||
"decoder.up.0.block.2.conv1.bias": "decoder.up_blocks.3.resnets.2.conv1.bias",
|
||||
"decoder.up.0.block.2.conv1.weight": "decoder.up_blocks.3.resnets.2.conv1.weight",
|
||||
"decoder.up.0.block.2.conv2.bias": "decoder.up_blocks.3.resnets.2.conv2.bias",
|
||||
"decoder.up.0.block.2.conv2.weight": "decoder.up_blocks.3.resnets.2.conv2.weight",
|
||||
"decoder.up.0.block.2.norm1.bias": "decoder.up_blocks.3.resnets.2.norm1.bias",
|
||||
"decoder.up.0.block.2.norm1.weight": "decoder.up_blocks.3.resnets.2.norm1.weight",
|
||||
"decoder.up.0.block.2.norm2.bias": "decoder.up_blocks.3.resnets.2.norm2.bias",
|
||||
"decoder.up.0.block.2.norm2.weight": "decoder.up_blocks.3.resnets.2.norm2.weight",
|
||||
"decoder.up.1.block.0.conv1.bias": "decoder.up_blocks.2.resnets.0.conv1.bias",
|
||||
"decoder.up.1.block.0.conv1.weight": "decoder.up_blocks.2.resnets.0.conv1.weight",
|
||||
"decoder.up.1.block.0.conv2.bias": "decoder.up_blocks.2.resnets.0.conv2.bias",
|
||||
"decoder.up.1.block.0.conv2.weight": "decoder.up_blocks.2.resnets.0.conv2.weight",
|
||||
"decoder.up.1.block.0.nin_shortcut.bias": "decoder.up_blocks.2.resnets.0.conv_shortcut.bias",
|
||||
"decoder.up.1.block.0.nin_shortcut.weight": "decoder.up_blocks.2.resnets.0.conv_shortcut.weight",
|
||||
"decoder.up.1.block.0.norm1.bias": "decoder.up_blocks.2.resnets.0.norm1.bias",
|
||||
"decoder.up.1.block.0.norm1.weight": "decoder.up_blocks.2.resnets.0.norm1.weight",
|
||||
"decoder.up.1.block.0.norm2.bias": "decoder.up_blocks.2.resnets.0.norm2.bias",
|
||||
"decoder.up.1.block.0.norm2.weight": "decoder.up_blocks.2.resnets.0.norm2.weight",
|
||||
"decoder.up.1.block.1.conv1.bias": "decoder.up_blocks.2.resnets.1.conv1.bias",
|
||||
"decoder.up.1.block.1.conv1.weight": "decoder.up_blocks.2.resnets.1.conv1.weight",
|
||||
"decoder.up.1.block.1.conv2.bias": "decoder.up_blocks.2.resnets.1.conv2.bias",
|
||||
"decoder.up.1.block.1.conv2.weight": "decoder.up_blocks.2.resnets.1.conv2.weight",
|
||||
"decoder.up.1.block.1.norm1.bias": "decoder.up_blocks.2.resnets.1.norm1.bias",
|
||||
"decoder.up.1.block.1.norm1.weight": "decoder.up_blocks.2.resnets.1.norm1.weight",
|
||||
"decoder.up.1.block.1.norm2.bias": "decoder.up_blocks.2.resnets.1.norm2.bias",
|
||||
"decoder.up.1.block.1.norm2.weight": "decoder.up_blocks.2.resnets.1.norm2.weight",
|
||||
"decoder.up.1.block.2.conv1.bias": "decoder.up_blocks.2.resnets.2.conv1.bias",
|
||||
"decoder.up.1.block.2.conv1.weight": "decoder.up_blocks.2.resnets.2.conv1.weight",
|
||||
"decoder.up.1.block.2.conv2.bias": "decoder.up_blocks.2.resnets.2.conv2.bias",
|
||||
"decoder.up.1.block.2.conv2.weight": "decoder.up_blocks.2.resnets.2.conv2.weight",
|
||||
"decoder.up.1.block.2.norm1.bias": "decoder.up_blocks.2.resnets.2.norm1.bias",
|
||||
"decoder.up.1.block.2.norm1.weight": "decoder.up_blocks.2.resnets.2.norm1.weight",
|
||||
"decoder.up.1.block.2.norm2.bias": "decoder.up_blocks.2.resnets.2.norm2.bias",
|
||||
"decoder.up.1.block.2.norm2.weight": "decoder.up_blocks.2.resnets.2.norm2.weight",
|
||||
"decoder.up.1.upsample.conv.bias": "decoder.up_blocks.2.upsamplers.0.conv.bias",
|
||||
"decoder.up.1.upsample.conv.weight": "decoder.up_blocks.2.upsamplers.0.conv.weight",
|
||||
"decoder.up.2.block.0.conv1.bias": "decoder.up_blocks.1.resnets.0.conv1.bias",
|
||||
"decoder.up.2.block.0.conv1.weight": "decoder.up_blocks.1.resnets.0.conv1.weight",
|
||||
"decoder.up.2.block.0.conv2.bias": "decoder.up_blocks.1.resnets.0.conv2.bias",
|
||||
"decoder.up.2.block.0.conv2.weight": "decoder.up_blocks.1.resnets.0.conv2.weight",
|
||||
"decoder.up.2.block.0.norm1.bias": "decoder.up_blocks.1.resnets.0.norm1.bias",
|
||||
"decoder.up.2.block.0.norm1.weight": "decoder.up_blocks.1.resnets.0.norm1.weight",
|
||||
"decoder.up.2.block.0.norm2.bias": "decoder.up_blocks.1.resnets.0.norm2.bias",
|
||||
"decoder.up.2.block.0.norm2.weight": "decoder.up_blocks.1.resnets.0.norm2.weight",
|
||||
"decoder.up.2.block.1.conv1.bias": "decoder.up_blocks.1.resnets.1.conv1.bias",
|
||||
"decoder.up.2.block.1.conv1.weight": "decoder.up_blocks.1.resnets.1.conv1.weight",
|
||||
"decoder.up.2.block.1.conv2.bias": "decoder.up_blocks.1.resnets.1.conv2.bias",
|
||||
"decoder.up.2.block.1.conv2.weight": "decoder.up_blocks.1.resnets.1.conv2.weight",
|
||||
"decoder.up.2.block.1.norm1.bias": "decoder.up_blocks.1.resnets.1.norm1.bias",
|
||||
"decoder.up.2.block.1.norm1.weight": "decoder.up_blocks.1.resnets.1.norm1.weight",
|
||||
"decoder.up.2.block.1.norm2.bias": "decoder.up_blocks.1.resnets.1.norm2.bias",
|
||||
"decoder.up.2.block.1.norm2.weight": "decoder.up_blocks.1.resnets.1.norm2.weight",
|
||||
"decoder.up.2.block.2.conv1.bias": "decoder.up_blocks.1.resnets.2.conv1.bias",
|
||||
"decoder.up.2.block.2.conv1.weight": "decoder.up_blocks.1.resnets.2.conv1.weight",
|
||||
"decoder.up.2.block.2.conv2.bias": "decoder.up_blocks.1.resnets.2.conv2.bias",
|
||||
"decoder.up.2.block.2.conv2.weight": "decoder.up_blocks.1.resnets.2.conv2.weight",
|
||||
"decoder.up.2.block.2.norm1.bias": "decoder.up_blocks.1.resnets.2.norm1.bias",
|
||||
"decoder.up.2.block.2.norm1.weight": "decoder.up_blocks.1.resnets.2.norm1.weight",
|
||||
"decoder.up.2.block.2.norm2.bias": "decoder.up_blocks.1.resnets.2.norm2.bias",
|
||||
"decoder.up.2.block.2.norm2.weight": "decoder.up_blocks.1.resnets.2.norm2.weight",
|
||||
"decoder.up.2.upsample.conv.bias": "decoder.up_blocks.1.upsamplers.0.conv.bias",
|
||||
"decoder.up.2.upsample.conv.weight": "decoder.up_blocks.1.upsamplers.0.conv.weight",
|
||||
"decoder.up.3.block.0.conv1.bias": "decoder.up_blocks.0.resnets.0.conv1.bias",
|
||||
"decoder.up.3.block.0.conv1.weight": "decoder.up_blocks.0.resnets.0.conv1.weight",
|
||||
"decoder.up.3.block.0.conv2.bias": "decoder.up_blocks.0.resnets.0.conv2.bias",
|
||||
"decoder.up.3.block.0.conv2.weight": "decoder.up_blocks.0.resnets.0.conv2.weight",
|
||||
"decoder.up.3.block.0.norm1.bias": "decoder.up_blocks.0.resnets.0.norm1.bias",
|
||||
"decoder.up.3.block.0.norm1.weight": "decoder.up_blocks.0.resnets.0.norm1.weight",
|
||||
"decoder.up.3.block.0.norm2.bias": "decoder.up_blocks.0.resnets.0.norm2.bias",
|
||||
"decoder.up.3.block.0.norm2.weight": "decoder.up_blocks.0.resnets.0.norm2.weight",
|
||||
"decoder.up.3.block.1.conv1.bias": "decoder.up_blocks.0.resnets.1.conv1.bias",
|
||||
"decoder.up.3.block.1.conv1.weight": "decoder.up_blocks.0.resnets.1.conv1.weight",
|
||||
"decoder.up.3.block.1.conv2.bias": "decoder.up_blocks.0.resnets.1.conv2.bias",
|
||||
"decoder.up.3.block.1.conv2.weight": "decoder.up_blocks.0.resnets.1.conv2.weight",
|
||||
"decoder.up.3.block.1.norm1.bias": "decoder.up_blocks.0.resnets.1.norm1.bias",
|
||||
"decoder.up.3.block.1.norm1.weight": "decoder.up_blocks.0.resnets.1.norm1.weight",
|
||||
"decoder.up.3.block.1.norm2.bias": "decoder.up_blocks.0.resnets.1.norm2.bias",
|
||||
"decoder.up.3.block.1.norm2.weight": "decoder.up_blocks.0.resnets.1.norm2.weight",
|
||||
"decoder.up.3.block.2.conv1.bias": "decoder.up_blocks.0.resnets.2.conv1.bias",
|
||||
"decoder.up.3.block.2.conv1.weight": "decoder.up_blocks.0.resnets.2.conv1.weight",
|
||||
"decoder.up.3.block.2.conv2.bias": "decoder.up_blocks.0.resnets.2.conv2.bias",
|
||||
"decoder.up.3.block.2.conv2.weight": "decoder.up_blocks.0.resnets.2.conv2.weight",
|
||||
"decoder.up.3.block.2.norm1.bias": "decoder.up_blocks.0.resnets.2.norm1.bias",
|
||||
"decoder.up.3.block.2.norm1.weight": "decoder.up_blocks.0.resnets.2.norm1.weight",
|
||||
"decoder.up.3.block.2.norm2.bias": "decoder.up_blocks.0.resnets.2.norm2.bias",
|
||||
"decoder.up.3.block.2.norm2.weight": "decoder.up_blocks.0.resnets.2.norm2.weight",
|
||||
"decoder.up.3.upsample.conv.bias": "decoder.up_blocks.0.upsamplers.0.conv.bias",
|
||||
"decoder.up.3.upsample.conv.weight": "decoder.up_blocks.0.upsamplers.0.conv.weight",
|
||||
"encoder.conv_in.bias": "encoder.conv_in.bias",
|
||||
"encoder.conv_in.weight": "encoder.conv_in.weight",
|
||||
"encoder.conv_out.bias": "encoder.conv_out.bias",
|
||||
"encoder.conv_out.weight": "encoder.conv_out.weight",
|
||||
"encoder.down.0.block.0.conv1.bias": "encoder.down_blocks.0.resnets.0.conv1.bias",
|
||||
"encoder.down.0.block.0.conv1.weight": "encoder.down_blocks.0.resnets.0.conv1.weight",
|
||||
"encoder.down.0.block.0.conv2.bias": "encoder.down_blocks.0.resnets.0.conv2.bias",
|
||||
"encoder.down.0.block.0.conv2.weight": "encoder.down_blocks.0.resnets.0.conv2.weight",
|
||||
"encoder.down.0.block.0.norm1.bias": "encoder.down_blocks.0.resnets.0.norm1.bias",
|
||||
"encoder.down.0.block.0.norm1.weight": "encoder.down_blocks.0.resnets.0.norm1.weight",
|
||||
"encoder.down.0.block.0.norm2.bias": "encoder.down_blocks.0.resnets.0.norm2.bias",
|
||||
"encoder.down.0.block.0.norm2.weight": "encoder.down_blocks.0.resnets.0.norm2.weight",
|
||||
"encoder.down.0.block.1.conv1.bias": "encoder.down_blocks.0.resnets.1.conv1.bias",
|
||||
"encoder.down.0.block.1.conv1.weight": "encoder.down_blocks.0.resnets.1.conv1.weight",
|
||||
"encoder.down.0.block.1.conv2.bias": "encoder.down_blocks.0.resnets.1.conv2.bias",
|
||||
"encoder.down.0.block.1.conv2.weight": "encoder.down_blocks.0.resnets.1.conv2.weight",
|
||||
"encoder.down.0.block.1.norm1.bias": "encoder.down_blocks.0.resnets.1.norm1.bias",
|
||||
"encoder.down.0.block.1.norm1.weight": "encoder.down_blocks.0.resnets.1.norm1.weight",
|
||||
"encoder.down.0.block.1.norm2.bias": "encoder.down_blocks.0.resnets.1.norm2.bias",
|
||||
"encoder.down.0.block.1.norm2.weight": "encoder.down_blocks.0.resnets.1.norm2.weight",
|
||||
"encoder.down.0.downsample.conv.bias": "encoder.down_blocks.0.downsamplers.0.conv.bias",
|
||||
"encoder.down.0.downsample.conv.weight": "encoder.down_blocks.0.downsamplers.0.conv.weight",
|
||||
"encoder.down.1.block.0.conv1.bias": "encoder.down_blocks.1.resnets.0.conv1.bias",
|
||||
"encoder.down.1.block.0.conv1.weight": "encoder.down_blocks.1.resnets.0.conv1.weight",
|
||||
"encoder.down.1.block.0.conv2.bias": "encoder.down_blocks.1.resnets.0.conv2.bias",
|
||||
"encoder.down.1.block.0.conv2.weight": "encoder.down_blocks.1.resnets.0.conv2.weight",
|
||||
"encoder.down.1.block.0.nin_shortcut.bias": "encoder.down_blocks.1.resnets.0.conv_shortcut.bias",
|
||||
"encoder.down.1.block.0.nin_shortcut.weight": "encoder.down_blocks.1.resnets.0.conv_shortcut.weight",
|
||||
"encoder.down.1.block.0.norm1.bias": "encoder.down_blocks.1.resnets.0.norm1.bias",
|
||||
"encoder.down.1.block.0.norm1.weight": "encoder.down_blocks.1.resnets.0.norm1.weight",
|
||||
"encoder.down.1.block.0.norm2.bias": "encoder.down_blocks.1.resnets.0.norm2.bias",
|
||||
"encoder.down.1.block.0.norm2.weight": "encoder.down_blocks.1.resnets.0.norm2.weight",
|
||||
"encoder.down.1.block.1.conv1.bias": "encoder.down_blocks.1.resnets.1.conv1.bias",
|
||||
"encoder.down.1.block.1.conv1.weight": "encoder.down_blocks.1.resnets.1.conv1.weight",
|
||||
"encoder.down.1.block.1.conv2.bias": "encoder.down_blocks.1.resnets.1.conv2.bias",
|
||||
"encoder.down.1.block.1.conv2.weight": "encoder.down_blocks.1.resnets.1.conv2.weight",
|
||||
"encoder.down.1.block.1.norm1.bias": "encoder.down_blocks.1.resnets.1.norm1.bias",
|
||||
"encoder.down.1.block.1.norm1.weight": "encoder.down_blocks.1.resnets.1.norm1.weight",
|
||||
"encoder.down.1.block.1.norm2.bias": "encoder.down_blocks.1.resnets.1.norm2.bias",
|
||||
"encoder.down.1.block.1.norm2.weight": "encoder.down_blocks.1.resnets.1.norm2.weight",
|
||||
"encoder.down.1.downsample.conv.bias": "encoder.down_blocks.1.downsamplers.0.conv.bias",
|
||||
"encoder.down.1.downsample.conv.weight": "encoder.down_blocks.1.downsamplers.0.conv.weight",
|
||||
"encoder.down.2.block.0.conv1.bias": "encoder.down_blocks.2.resnets.0.conv1.bias",
|
||||
"encoder.down.2.block.0.conv1.weight": "encoder.down_blocks.2.resnets.0.conv1.weight",
|
||||
"encoder.down.2.block.0.conv2.bias": "encoder.down_blocks.2.resnets.0.conv2.bias",
|
||||
"encoder.down.2.block.0.conv2.weight": "encoder.down_blocks.2.resnets.0.conv2.weight",
|
||||
"encoder.down.2.block.0.nin_shortcut.bias": "encoder.down_blocks.2.resnets.0.conv_shortcut.bias",
|
||||
"encoder.down.2.block.0.nin_shortcut.weight": "encoder.down_blocks.2.resnets.0.conv_shortcut.weight",
|
||||
"encoder.down.2.block.0.norm1.bias": "encoder.down_blocks.2.resnets.0.norm1.bias",
|
||||
"encoder.down.2.block.0.norm1.weight": "encoder.down_blocks.2.resnets.0.norm1.weight",
|
||||
"encoder.down.2.block.0.norm2.bias": "encoder.down_blocks.2.resnets.0.norm2.bias",
|
||||
"encoder.down.2.block.0.norm2.weight": "encoder.down_blocks.2.resnets.0.norm2.weight",
|
||||
"encoder.down.2.block.1.conv1.bias": "encoder.down_blocks.2.resnets.1.conv1.bias",
|
||||
"encoder.down.2.block.1.conv1.weight": "encoder.down_blocks.2.resnets.1.conv1.weight",
|
||||
"encoder.down.2.block.1.conv2.bias": "encoder.down_blocks.2.resnets.1.conv2.bias",
|
||||
"encoder.down.2.block.1.conv2.weight": "encoder.down_blocks.2.resnets.1.conv2.weight",
|
||||
"encoder.down.2.block.1.norm1.bias": "encoder.down_blocks.2.resnets.1.norm1.bias",
|
||||
"encoder.down.2.block.1.norm1.weight": "encoder.down_blocks.2.resnets.1.norm1.weight",
|
||||
"encoder.down.2.block.1.norm2.bias": "encoder.down_blocks.2.resnets.1.norm2.bias",
|
||||
"encoder.down.2.block.1.norm2.weight": "encoder.down_blocks.2.resnets.1.norm2.weight",
|
||||
"encoder.down.2.downsample.conv.bias": "encoder.down_blocks.2.downsamplers.0.conv.bias",
|
||||
"encoder.down.2.downsample.conv.weight": "encoder.down_blocks.2.downsamplers.0.conv.weight",
|
||||
"encoder.down.3.block.0.conv1.bias": "encoder.down_blocks.3.resnets.0.conv1.bias",
|
||||
"encoder.down.3.block.0.conv1.weight": "encoder.down_blocks.3.resnets.0.conv1.weight",
|
||||
"encoder.down.3.block.0.conv2.bias": "encoder.down_blocks.3.resnets.0.conv2.bias",
|
||||
"encoder.down.3.block.0.conv2.weight": "encoder.down_blocks.3.resnets.0.conv2.weight",
|
||||
"encoder.down.3.block.0.norm1.bias": "encoder.down_blocks.3.resnets.0.norm1.bias",
|
||||
"encoder.down.3.block.0.norm1.weight": "encoder.down_blocks.3.resnets.0.norm1.weight",
|
||||
"encoder.down.3.block.0.norm2.bias": "encoder.down_blocks.3.resnets.0.norm2.bias",
|
||||
"encoder.down.3.block.0.norm2.weight": "encoder.down_blocks.3.resnets.0.norm2.weight",
|
||||
"encoder.down.3.block.1.conv1.bias": "encoder.down_blocks.3.resnets.1.conv1.bias",
|
||||
"encoder.down.3.block.1.conv1.weight": "encoder.down_blocks.3.resnets.1.conv1.weight",
|
||||
"encoder.down.3.block.1.conv2.bias": "encoder.down_blocks.3.resnets.1.conv2.bias",
|
||||
"encoder.down.3.block.1.conv2.weight": "encoder.down_blocks.3.resnets.1.conv2.weight",
|
||||
"encoder.down.3.block.1.norm1.bias": "encoder.down_blocks.3.resnets.1.norm1.bias",
|
||||
"encoder.down.3.block.1.norm1.weight": "encoder.down_blocks.3.resnets.1.norm1.weight",
|
||||
"encoder.down.3.block.1.norm2.bias": "encoder.down_blocks.3.resnets.1.norm2.bias",
|
||||
"encoder.down.3.block.1.norm2.weight": "encoder.down_blocks.3.resnets.1.norm2.weight",
|
||||
"encoder.mid.attn_1.k.bias": "encoder.mid_block.attentions.0.to_k.bias",
|
||||
"encoder.mid.attn_1.k.weight": "encoder.mid_block.attentions.0.to_k.weight",
|
||||
"encoder.mid.attn_1.norm.bias": "encoder.mid_block.attentions.0.group_norm.bias",
|
||||
"encoder.mid.attn_1.norm.weight": "encoder.mid_block.attentions.0.group_norm.weight",
|
||||
"encoder.mid.attn_1.proj_out.bias": "encoder.mid_block.attentions.0.to_out.0.bias",
|
||||
"encoder.mid.attn_1.proj_out.weight": "encoder.mid_block.attentions.0.to_out.0.weight",
|
||||
"encoder.mid.attn_1.q.bias": "encoder.mid_block.attentions.0.to_q.bias",
|
||||
"encoder.mid.attn_1.q.weight": "encoder.mid_block.attentions.0.to_q.weight",
|
||||
"encoder.mid.attn_1.v.bias": "encoder.mid_block.attentions.0.to_v.bias",
|
||||
"encoder.mid.attn_1.v.weight": "encoder.mid_block.attentions.0.to_v.weight",
|
||||
"encoder.mid.block_1.conv1.bias": "encoder.mid_block.resnets.0.conv1.bias",
|
||||
"encoder.mid.block_1.conv1.weight": "encoder.mid_block.resnets.0.conv1.weight",
|
||||
"encoder.mid.block_1.conv2.bias": "encoder.mid_block.resnets.0.conv2.bias",
|
||||
"encoder.mid.block_1.conv2.weight": "encoder.mid_block.resnets.0.conv2.weight",
|
||||
"encoder.mid.block_1.norm1.bias": "encoder.mid_block.resnets.0.norm1.bias",
|
||||
"encoder.mid.block_1.norm1.weight": "encoder.mid_block.resnets.0.norm1.weight",
|
||||
"encoder.mid.block_1.norm2.bias": "encoder.mid_block.resnets.0.norm2.bias",
|
||||
"encoder.mid.block_1.norm2.weight": "encoder.mid_block.resnets.0.norm2.weight",
|
||||
"encoder.mid.block_2.conv1.bias": "encoder.mid_block.resnets.1.conv1.bias",
|
||||
"encoder.mid.block_2.conv1.weight": "encoder.mid_block.resnets.1.conv1.weight",
|
||||
"encoder.mid.block_2.conv2.bias": "encoder.mid_block.resnets.1.conv2.bias",
|
||||
"encoder.mid.block_2.conv2.weight": "encoder.mid_block.resnets.1.conv2.weight",
|
||||
"encoder.mid.block_2.norm1.bias": "encoder.mid_block.resnets.1.norm1.bias",
|
||||
"encoder.mid.block_2.norm1.weight": "encoder.mid_block.resnets.1.norm1.weight",
|
||||
"encoder.mid.block_2.norm2.bias": "encoder.mid_block.resnets.1.norm2.bias",
|
||||
"encoder.mid.block_2.norm2.weight": "encoder.mid_block.resnets.1.norm2.weight",
|
||||
"encoder.norm_out.bias": "encoder.conv_norm_out.bias",
|
||||
"encoder.norm_out.weight": "encoder.conv_norm_out.weight",
|
||||
"post_quant_conv.bias": "post_quant_conv.bias",
|
||||
"post_quant_conv.weight": "post_quant_conv.weight",
|
||||
"quant_conv.bias": "quant_conv.bias",
|
||||
"quant_conv.weight": "quant_conv.weight"
|
||||
}
|
||||
|
||||
|
||||
def get_diffusers_vae_key_from_ldm_key(target_ldm_key, i=None):
|
||||
for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items():
|
||||
if i is not None:
|
||||
ldm_key = ldm_key.replace("{i}", str(i))
|
||||
diffusers_key = diffusers_key.replace("{i}", str(i))
|
||||
if ldm_key == target_ldm_key:
|
||||
return diffusers_key
|
||||
|
||||
if ldm_key in vae_ldm_to_diffusers_dict:
|
||||
return vae_ldm_to_diffusers_dict[ldm_key]
|
||||
else:
|
||||
return None
|
||||
|
||||
# def get_ldm_vae_key_from_diffusers_key(target_diffusers_key):
|
||||
# for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items():
|
||||
# if diffusers_key == target_diffusers_key:
|
||||
# return ldm_key
|
||||
# return None
|
||||
|
||||
def get_ldm_vae_key_from_diffusers_key(target_diffusers_key):
|
||||
for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items():
|
||||
if "{" in diffusers_key: # if we have a placeholder
|
||||
# escape special characters in the key, and replace the placeholder with a regex group
|
||||
pattern = re.escape(diffusers_key).replace("\\{i\\}", "(\\d+)")
|
||||
match = re.match(pattern, target_diffusers_key)
|
||||
if match: # if we found a match
|
||||
return ldm_key.format(i=match.group(1))
|
||||
elif diffusers_key == target_diffusers_key:
|
||||
return ldm_key
|
||||
return None
|
||||
|
||||
|
||||
vae_keys_squished_on_diffusers = [
|
||||
"decoder.mid_block.attentions.0.to_k.weight",
|
||||
"decoder.mid_block.attentions.0.to_out.0.weight",
|
||||
"decoder.mid_block.attentions.0.to_q.weight",
|
||||
"decoder.mid_block.attentions.0.to_v.weight",
|
||||
"encoder.mid_block.attentions.0.to_k.weight",
|
||||
"encoder.mid_block.attentions.0.to_out.0.weight",
|
||||
"encoder.mid_block.attentions.0.to_q.weight",
|
||||
"encoder.mid_block.attentions.0.to_v.weight"
|
||||
]
|
||||
|
||||
def convert_diffusers_back_to_ldm(diffusers_vae):
|
||||
new_state_dict = OrderedDict()
|
||||
diffusers_state_dict = diffusers_vae.state_dict()
|
||||
for key, value in diffusers_state_dict.items():
|
||||
val_to_save = value
|
||||
if key in vae_keys_squished_on_diffusers:
|
||||
val_to_save = value.clone()
|
||||
# (512, 512) diffusers and (512, 512, 1, 1) ldm
|
||||
val_to_save = val_to_save.unsqueeze(-1).unsqueeze(-1)
|
||||
ldm_key = get_ldm_vae_key_from_diffusers_key(key)
|
||||
if ldm_key is not None:
|
||||
new_state_dict[ldm_key] = val_to_save
|
||||
else:
|
||||
# for now add current key
|
||||
new_state_dict[key] = val_to_save
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
mapping = {}
|
||||
# extract state dict for VAE
|
||||
vae_state_dict = {}
|
||||
vae_key = "first_stage_model."
|
||||
@@ -390,6 +721,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
|
||||
new_checkpoint = {}
|
||||
|
||||
# for key in list(vae_state_dict.keys()):
|
||||
# diffusers_key = get_diffusers_vae_key_from_ldm_key(key)
|
||||
# if diffusers_key is not None:
|
||||
# new_checkpoint[diffusers_key] = vae_state_dict[key]
|
||||
|
||||
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
||||
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
||||
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
||||
@@ -411,11 +747,13 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
|
||||
# Retrieves the keys for the encoder down blocks only
|
||||
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
||||
down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
|
||||
down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in
|
||||
range(num_down_blocks)}
|
||||
|
||||
# Retrieves the keys for the decoder up blocks only
|
||||
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
||||
up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
|
||||
up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in
|
||||
range(num_up_blocks)}
|
||||
|
||||
for i in range(num_down_blocks):
|
||||
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
||||
@@ -424,9 +762,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
||||
f"encoder.down.{i}.downsample.conv.weight"
|
||||
)
|
||||
mapping[f"encoder.down.{i}.downsample.conv.weight"] = f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"
|
||||
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
||||
f"encoder.down.{i}.downsample.conv.bias"
|
||||
)
|
||||
mapping[f"encoder.down.{i}.downsample.conv.bias"] = f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"
|
||||
|
||||
paths = renew_vae_resnet_paths(resnets)
|
||||
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
||||
@@ -449,15 +789,18 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
|
||||
for i in range(num_up_blocks):
|
||||
block_id = num_up_blocks - 1 - i
|
||||
resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
|
||||
resnets = [key for key in up_blocks[block_id] if
|
||||
f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
|
||||
|
||||
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
||||
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
||||
f"decoder.up.{block_id}.upsample.conv.weight"
|
||||
]
|
||||
mapping[f"decoder.up.{block_id}.upsample.conv.weight"] = f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"
|
||||
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
||||
f"decoder.up.{block_id}.upsample.conv.bias"
|
||||
]
|
||||
mapping[f"decoder.up.{block_id}.upsample.conv.bias"] = f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"
|
||||
|
||||
paths = renew_vae_resnet_paths(resnets)
|
||||
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
||||
@@ -548,7 +891,7 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
|
||||
text_model_dict = {}
|
||||
for key in keys:
|
||||
if key.startswith("cond_stage_model.transformer"):
|
||||
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
||||
text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
|
||||
return text_model_dict
|
||||
|
||||
|
||||
@@ -677,36 +1020,36 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
if i > 0:
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
||||
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
@@ -715,7 +1058,7 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2 * j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
# buyer beware: this is a *brittle* function,
|
||||
@@ -772,20 +1115,20 @@ def convert_vae_state_dict(vae_state_dict):
|
||||
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"up.{3-i}.upsample."
|
||||
sd_upsample_prefix = f"up.{3 - i}.upsample."
|
||||
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
# up_blocks have three resnets
|
||||
# also, up blocks in hf are numbered in reverse from sd
|
||||
for j in range(3):
|
||||
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
||||
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
||||
sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
|
||||
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
||||
|
||||
# this part accounts for mid blocks in both the encoder and the decoder
|
||||
for i in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
||||
sd_mid_res_prefix = f"mid.block_{i+1}."
|
||||
sd_mid_res_prefix = f"mid.block_{i + 1}."
|
||||
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
vae_conversion_map_attn = [
|
||||
@@ -850,7 +1193,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
|
||||
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
||||
for key in state_dict.keys():
|
||||
if key.startswith(rep_from):
|
||||
new_key = rep_to + key[len(rep_from) :]
|
||||
new_key = rep_to + key[len(rep_from):]
|
||||
key_reps.append((key, new_key))
|
||||
|
||||
for key, new_key in key_reps:
|
||||
@@ -861,7 +1204,8 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
|
||||
|
||||
|
||||
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
||||
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=False):
|
||||
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None,
|
||||
unet_use_linear_projection_in_v2=False):
|
||||
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
|
||||
|
||||
# Convert the UNet2DConditionModel model.
|
||||
@@ -990,7 +1334,8 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
|
||||
return new_sd
|
||||
|
||||
|
||||
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
|
||||
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None,
|
||||
vae=None):
|
||||
if ckpt_path is not None:
|
||||
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
||||
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
||||
@@ -1059,7 +1404,8 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
|
||||
return key_count
|
||||
|
||||
|
||||
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
|
||||
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None,
|
||||
use_safetensors=False):
|
||||
if pretrained_model_name_or_path is None:
|
||||
# load default settings for v1/v2
|
||||
if v2:
|
||||
@@ -1177,4 +1523,4 @@ if __name__ == "__main__":
|
||||
for ar in aspect_ratios:
|
||||
if ar in ars:
|
||||
print("error! duplicate ar:", ar)
|
||||
ars.add(ar)
|
||||
ars.add(ar)
|
||||
|
||||
23
toolkit/losses.py
Normal file
23
toolkit/losses.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import torch
|
||||
|
||||
|
||||
def total_variation(image):
    """
    Compute the anisotropic total variation of a batch of images,
    normalized by the per-image element count (C * H * W).

    Inputs:
    - image: tensor of shape (N, C, H, W)
    Returns:
    - scalar tensor: summed absolute neighbor differences / (C * H * W)
    """
    _, channels, height, width = image.shape
    per_image_elements = channels * height * width
    # horizontal neighbor differences (along W)
    tv_w = torch.abs(image[:, :, :, 1:] - image[:, :, :, :-1]).sum()
    # vertical neighbor differences (along H)
    tv_h = torch.abs(image[:, :, 1:, :] - image[:, :, :-1, :]).sum()
    return (tv_w + tv_h) / per_image_elements
|
||||
|
||||
|
||||
class ComparativeTotalVariation(torch.nn.Module):
    """
    Loss that matches the total variation of a prediction to that of a
    target image batch: |TV(pred) - TV(target)|.
    """

    def forward(self, pred, target):
        # compare the two scalar TV values, not the images themselves
        tv_pred = total_variation(pred)
        tv_target = total_variation(target)
        return torch.abs(tv_pred - tv_target)
|
||||
@@ -127,11 +127,12 @@ class Normalization(nn.Module):
|
||||
|
||||
def get_style_model_and_losses(
|
||||
single_target=False,
|
||||
device='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
device='cuda' if torch.cuda.is_available() else 'cpu',
|
||||
output_layer_name=None,
|
||||
):
|
||||
# content_layers = ['conv_4']
|
||||
# style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
|
||||
content_layers = ['conv3_2', 'conv4_2']
|
||||
content_layers = ['conv4_2']
|
||||
style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
|
||||
cnn = models.vgg19(pretrained=True).features.to(device).eval()
|
||||
# normalization module
|
||||
@@ -150,6 +151,8 @@ def get_style_model_and_losses(
|
||||
block = 1
|
||||
children = list(cnn.children())
|
||||
|
||||
output_layer = None
|
||||
|
||||
for layer in children:
|
||||
if isinstance(layer, nn.Conv2d):
|
||||
i += 1
|
||||
@@ -184,11 +187,16 @@ def get_style_model_and_losses(
|
||||
model.add_module("style_loss_{}_{}".format(block, i), style_loss)
|
||||
style_losses.append(style_loss)
|
||||
|
||||
if output_layer_name is not None and name == output_layer_name:
|
||||
output_layer = layer
|
||||
|
||||
# now we trim off the layers after the last content and style losses
|
||||
for i in range(len(model) - 1, -1, -1):
|
||||
if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
|
||||
break
|
||||
if output_layer_name is not None and model[i].name == output_layer_name:
|
||||
break
|
||||
|
||||
model = model[:(i + 1)]
|
||||
|
||||
return model, style_losses, content_losses
|
||||
return model, style_losses, content_losses, output_layer
|
||||
|
||||
Reference in New Issue
Block a user