Fixed weight mapping for refiner

This commit is contained in:
Jaret Burkett
2023-11-06 07:37:47 -07:00
parent 93ea955d7c
commit a8b3b8b8da
4 changed files with 702 additions and 675 deletions

View File

@@ -70,8 +70,10 @@ diffusers_file_path = file_path
if args.ssd:
diffusers_file_path = "segmind/SSD-1B"
if args.refiner:
diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
# if args.refiner:
# diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
diffusers_file_path = file_path if len(args.file_1) == 1 else args.file_1[1]
if not args.refiner:
@@ -98,11 +100,17 @@ else:
# refiner wont work directly with stable diffusion
# so we need to load the model and then load the state dict
diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
file_path,
diffusers_file_path,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
).to(device)
# diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
# file_path,
# torch_dtype=torch.float16,
# use_safetensors=True,
# variant="fp16",
# ).to(device)
SD_PREFIX_VAE = "vae"
SD_PREFIX_UNET = "unet"
@@ -141,7 +149,7 @@ total_keys = len(ldm_dict_keys)
matched_ldm_keys = []
matched_diffusers_keys = []
error_margin = 1e-6
error_margin = 1e-8
tmp_merge_key = "TMP___MERGE"
@@ -172,6 +180,9 @@ if args.sdxl or args.sd2 or args.ssd or args.refiner:
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
elif "conditioner.embedders.0.model.text_projection" in ldm_dict_keys:
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
d_model = int(ldm_state_dict["conditioner.embedders.0.model.text_projection"].shape[0])
else:
d_model = 1024
@@ -291,11 +302,11 @@ for ldm_key in ldm_dict_keys:
diffusers_reduced_shape_tuple = get_reduced_shape(diffusers_shape_tuple)
# That was easy. Same key
if ldm_key == diffusers_key:
ldm_diffusers_keymap[ldm_key] = diffusers_key
matched_ldm_keys.append(ldm_key)
matched_diffusers_keys.append(diffusers_key)
break
# if ldm_key == diffusers_key:
# ldm_diffusers_keymap[ldm_key] = diffusers_key
# matched_ldm_keys.append(ldm_key)
# matched_diffusers_keys.append(diffusers_key)
# break
# if we already have this key mapped, skip it
if diffusers_key in matched_diffusers_keys:
@@ -320,7 +331,7 @@ for ldm_key in ldm_dict_keys:
did_reduce_diffusers = True
# check to see if they match within a margin of error
mse = torch.nn.functional.mse_loss(ldm_weight, diffusers_weight)
mse = torch.nn.functional.mse_loss(ldm_weight.float(), diffusers_weight.float())
if mse < error_margin:
ldm_diffusers_keymap[ldm_key] = diffusers_key
matched_ldm_keys.append(ldm_key)