mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Added refiner fine tuning. Works, but needs some polish.
This commit is contained in:
@@ -6,6 +6,8 @@ import os
|
||||
# add project root to sys path
|
||||
import sys
|
||||
|
||||
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import torch
|
||||
@@ -50,6 +52,7 @@ parser.add_argument(
|
||||
|
||||
parser.add_argument('--name', type=str, default='stable_diffusion', help='name for mapping to make')
|
||||
parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
|
||||
parser.add_argument('--refiner', action='store_true', help='is refiner model')
|
||||
parser.add_argument('--ssd', action='store_true', help='is ssd model')
|
||||
parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
|
||||
|
||||
@@ -61,29 +64,68 @@ find_matches = False
|
||||
|
||||
print(f'Loading diffusers model')
|
||||
|
||||
ignore_ldm_begins_with = []
|
||||
|
||||
diffusers_file_path = file_path
|
||||
if args.ssd:
|
||||
diffusers_file_path = "segmind/SSD-1B"
|
||||
|
||||
diffusers_model_config = ModelConfig(
|
||||
name_or_path=diffusers_file_path,
|
||||
is_xl=args.sdxl,
|
||||
is_v2=args.sd2,
|
||||
is_ssd=args.ssd,
|
||||
dtype=dtype,
|
||||
)
|
||||
diffusers_sd = StableDiffusion(
|
||||
model_config=diffusers_model_config,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
diffusers_sd.load_model()
|
||||
# delete things we don't need
|
||||
del diffusers_sd.tokenizer
|
||||
flush()
|
||||
if args.refiner:
|
||||
diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
|
||||
|
||||
if not args.refiner:
|
||||
|
||||
diffusers_model_config = ModelConfig(
|
||||
name_or_path=diffusers_file_path,
|
||||
is_xl=args.sdxl,
|
||||
is_v2=args.sd2,
|
||||
is_ssd=args.ssd,
|
||||
dtype=dtype,
|
||||
)
|
||||
diffusers_sd = StableDiffusion(
|
||||
model_config=diffusers_model_config,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
diffusers_sd.load_model()
|
||||
# delete things we don't need
|
||||
del diffusers_sd.tokenizer
|
||||
flush()
|
||||
|
||||
print(f'Loading ldm model')
|
||||
diffusers_state_dict = diffusers_sd.state_dict()
|
||||
else:
|
||||
# refiner won't work directly with stable diffusion
|
||||
# so we need to load the model and then load the state dict
|
||||
diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||
file_path,
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
).to(device)
|
||||
|
||||
SD_PREFIX_VAE = "vae"
|
||||
SD_PREFIX_UNET = "unet"
|
||||
SD_PREFIX_REFINER_UNET = "refiner_unet"
|
||||
SD_PREFIX_TEXT_ENCODER = "te"
|
||||
|
||||
SD_PREFIX_TEXT_ENCODER1 = "te0"
|
||||
SD_PREFIX_TEXT_ENCODER2 = "te1"
|
||||
|
||||
diffusers_state_dict = OrderedDict()
|
||||
for k, v in diffusers_pipeline.vae.state_dict().items():
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}"
|
||||
diffusers_state_dict[new_key] = v
|
||||
for k, v in diffusers_pipeline.text_encoder_2.state_dict().items():
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}"
|
||||
diffusers_state_dict[new_key] = v
|
||||
for k, v in diffusers_pipeline.unet.state_dict().items():
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
|
||||
diffusers_state_dict[new_key] = v
|
||||
|
||||
# add the ones to ignore, as we are only going to focus on the unet and copy the rest
|
||||
# ignore_ldm_begins_with = ["conditioner.", "first_stage_model."]
|
||||
|
||||
print(f'Loading ldm model')
|
||||
diffusers_state_dict = diffusers_sd.state_dict()
|
||||
diffusers_dict_keys = list(diffusers_state_dict.keys())
|
||||
|
||||
ldm_state_dict = load_file(file_path)
|
||||
@@ -113,6 +155,12 @@ if args.sdxl or args.ssd:
|
||||
proj_pattern_weight = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
|
||||
proj_pattern_bias = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
|
||||
text_proj_layer = "conditioner.embedders.1.model.text_projection"
|
||||
if args.refiner:
|
||||
te_suffix = '1'
|
||||
ldm_res_block_prefix = "conditioner.embedders.0.model.transformer.resblocks"
|
||||
proj_pattern_weight = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
|
||||
proj_pattern_bias = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
|
||||
text_proj_layer = "conditioner.embedders.0.model.text_projection"
|
||||
if args.sd2:
|
||||
te_suffix = ''
|
||||
ldm_res_block_prefix = "cond_stage_model.model.transformer.resblocks"
|
||||
@@ -120,7 +168,7 @@ if args.sd2:
|
||||
proj_pattern_bias = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
|
||||
text_proj_layer = "cond_stage_model.model.text_projection"
|
||||
|
||||
if args.sdxl or args.sd2 or args.ssd:
|
||||
if args.sdxl or args.sd2 or args.ssd or args.refiner:
|
||||
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
|
||||
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
|
||||
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
|
||||
@@ -297,6 +345,8 @@ if args.sdxl:
|
||||
name += '_sdxl'
|
||||
elif args.ssd:
|
||||
name += '_ssd'
|
||||
elif args.refiner:
|
||||
name += '_refiner'
|
||||
elif args.sd2:
|
||||
name += '_sd2'
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user