Added refiner fine tuning. Works, but needs some polish.

This commit is contained in:
Jaret Burkett
2023-11-05 17:15:03 -07:00
parent 8a9e8f708f
commit 93ea955d7c
14 changed files with 4541 additions and 128 deletions

View File

@@ -6,6 +6,8 @@ import os
# add project root to sys path
import sys
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
@@ -50,6 +52,7 @@ parser.add_argument(
parser.add_argument('--name', type=str, default='stable_diffusion', help='name for mapping to make')
parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
parser.add_argument('--refiner', action='store_true', help='is refiner model')
parser.add_argument('--ssd', action='store_true', help='is ssd model')
parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
@@ -61,29 +64,68 @@ find_matches = False
print(f'Loading diffusers model')
ignore_ldm_begins_with = []
diffusers_file_path = file_path
if args.ssd:
diffusers_file_path = "segmind/SSD-1B"
diffusers_model_config = ModelConfig(
name_or_path=diffusers_file_path,
is_xl=args.sdxl,
is_v2=args.sd2,
is_ssd=args.ssd,
dtype=dtype,
)
diffusers_sd = StableDiffusion(
model_config=diffusers_model_config,
device=device,
dtype=dtype,
)
diffusers_sd.load_model()
# delete things we dont need
del diffusers_sd.tokenizer
flush()
if args.refiner:
diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
if not args.refiner:
diffusers_model_config = ModelConfig(
name_or_path=diffusers_file_path,
is_xl=args.sdxl,
is_v2=args.sd2,
is_ssd=args.ssd,
dtype=dtype,
)
diffusers_sd = StableDiffusion(
model_config=diffusers_model_config,
device=device,
dtype=dtype,
)
diffusers_sd.load_model()
# delete things we dont need
del diffusers_sd.tokenizer
flush()
print(f'Loading ldm model')
diffusers_state_dict = diffusers_sd.state_dict()
else:
# refiner wont work directly with stable diffusion
# so we need to load the model and then load the state dict
diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
file_path,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
).to(device)
SD_PREFIX_VAE = "vae"
SD_PREFIX_UNET = "unet"
SD_PREFIX_REFINER_UNET = "refiner_unet"
SD_PREFIX_TEXT_ENCODER = "te"
SD_PREFIX_TEXT_ENCODER1 = "te0"
SD_PREFIX_TEXT_ENCODER2 = "te1"
diffusers_state_dict = OrderedDict()
for k, v in diffusers_pipeline.vae.state_dict().items():
new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}"
diffusers_state_dict[new_key] = v
for k, v in diffusers_pipeline.text_encoder_2.state_dict().items():
new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}"
diffusers_state_dict[new_key] = v
for k, v in diffusers_pipeline.unet.state_dict().items():
new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
diffusers_state_dict[new_key] = v
# add ignore ones as we are only going to focus on unet and copy the rest
# ignore_ldm_begins_with = ["conditioner.", "first_stage_model."]
print(f'Loading ldm model')
diffusers_state_dict = diffusers_sd.state_dict()
diffusers_dict_keys = list(diffusers_state_dict.keys())
ldm_state_dict = load_file(file_path)
@@ -113,6 +155,12 @@ if args.sdxl or args.ssd:
proj_pattern_weight = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
proj_pattern_bias = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
text_proj_layer = "conditioner.embedders.1.model.text_projection"
if args.refiner:
te_suffix = '1'
ldm_res_block_prefix = "conditioner.embedders.0.model.transformer.resblocks"
proj_pattern_weight = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
proj_pattern_bias = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
text_proj_layer = "conditioner.embedders.0.model.text_projection"
if args.sd2:
te_suffix = ''
ldm_res_block_prefix = "cond_stage_model.model.transformer.resblocks"
@@ -120,7 +168,7 @@ if args.sd2:
proj_pattern_bias = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
text_proj_layer = "cond_stage_model.model.text_projection"
if args.sdxl or args.sd2 or args.ssd:
if args.sdxl or args.sd2 or args.ssd or args.refiner:
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
@@ -297,6 +345,8 @@ if args.sdxl:
name += '_sdxl'
elif args.ssd:
name += '_ssd'
elif args.refiner:
name += '_refiner'
elif args.sd2:
name += '_sd2'
else: