mirror of https://github.com/ostris/ai-toolkit.git
Fixed weight mapping for refiner
@@ -70,8 +70,10 @@ diffusers_file_path = file_path
 if args.ssd:
     diffusers_file_path = "segmind/SSD-1B"

-if args.refiner:
-    diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
+# if args.refiner:
+#     diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0"

 diffusers_file_path = file_path if len(args.file_1) == 1 else args.file_1[1]

 if not args.refiner:
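The hunk above reads several command-line attributes (args.ssd, args.refiner, args.file_1, and, per the later hunks, args.sdxl and args.sd2). A hypothetical argparse setup that is merely consistent with those names, for orientation only; the script's real flags, help strings, and defaults may differ:

import argparse

# Hypothetical parser matching the attribute names referenced in the diff;
# not the repository's actual CLI definition.
parser = argparse.ArgumentParser(description="generate LDM <-> diffusers weight mappings")
parser.add_argument("file_1", nargs="+", help="LDM checkpoint, optionally followed by a diffusers-format reference checkpoint")
parser.add_argument("--sdxl", action="store_true", help="checkpoint is SDXL")
parser.add_argument("--sd2", action="store_true", help="checkpoint is SD 2.x")
parser.add_argument("--ssd", action="store_true", help="compare against segmind/SSD-1B")
parser.add_argument("--refiner", action="store_true", help="checkpoint is the SDXL refiner")
args = parser.parse_args()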
@@ -98,11 +100,17 @@ else:
     # refiner wont work directly with stable diffusion
     # so we need to load the model and then load the state dict
     diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
-        file_path,
+        diffusers_file_path,
         torch_dtype=torch.float16,
         use_safetensors=True,
         variant="fp16",
     ).to(device)
+    # diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
+    #     file_path,
+    #     torch_dtype=torch.float16,
+    #     use_safetensors=True,
+    #     variant="fp16",
+    # ).to(device)

 SD_PREFIX_VAE = "vae"
 SD_PREFIX_UNET = "unet"
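The SD_PREFIX_* constants in the trailing context hint at how the comparison is staged: the diffusers pipeline's per-component state dicts are flattened into one dictionary keyed by a component prefix so they can be matched against the flat LDM keys. A rough sketch of that flattening step, with an assumed separator and component list rather than the script's actual ones:

# Rough sketch, not the script's actual code: flatten each pipeline
# component's state dict into one dictionary keyed by a component prefix,
# so diffusers tensors can be compared key-by-key against LDM tensors.
# The underscore separator and the component set are assumptions.
SD_PREFIX_VAE = "vae"
SD_PREFIX_UNET = "unet"

def flatten_pipeline_state_dict(pipe):
    components = {SD_PREFIX_VAE: pipe.vae, SD_PREFIX_UNET: pipe.unet}
    flat = {}
    for prefix, module in components.items():
        for key, tensor in module.state_dict().items():
            flat[f"{prefix}_{key}"] = tensor
    return flat

Called as flat = flatten_pipeline_state_dict(diffusers_pipeline), this would produce keys such as "unet_conv_in.weight" to match against the LDM dictionary.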
@@ -141,7 +149,7 @@ total_keys = len(ldm_dict_keys)
 matched_ldm_keys = []
 matched_diffusers_keys = []

-error_margin = 1e-6
+error_margin = 1e-8

 tmp_merge_key = "TMP___MERGE"

@@ -172,6 +180,9 @@ if args.sdxl or args.sd2 or args.ssd or args.refiner:
     if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
         # d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
         d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
+    elif "conditioner.embedders.0.model.text_projection" in ldm_dict_keys:
+        # d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
+        d_model = int(ldm_state_dict["conditioner.embedders.0.model.text_projection"].shape[0])
     else:
         d_model = 1024

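The embedders.0 branch matters for the refiner: its single OpenCLIP text encoder sits at conditioner.embedders.0, whereas the SDXL base model keeps it at embedders.1. The d_model value read from the text_projection shape is what conversion code typically uses to split OpenCLIP's fused attention in_proj weight into separate q/k/v projections in the diffusers layout; a generic sketch of that split, not necessarily this script's exact code:

import torch

def split_in_proj(in_proj_weight: torch.Tensor, d_model: int):
    # OpenCLIP stores attention q/k/v as one fused [3 * d_model, d_model]
    # matrix; diffusers/transformers expect three separate projections.
    assert in_proj_weight.shape == (3 * d_model, d_model)
    q_proj, k_proj, v_proj = torch.chunk(in_proj_weight, 3, dim=0)
    return q_proj, k_proj, v_proj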
@@ -291,11 +302,11 @@ for ldm_key in ldm_dict_keys:
         diffusers_reduced_shape_tuple = get_reduced_shape(diffusers_shape_tuple)

         # That was easy. Same key
-        if ldm_key == diffusers_key:
-            ldm_diffusers_keymap[ldm_key] = diffusers_key
-            matched_ldm_keys.append(ldm_key)
-            matched_diffusers_keys.append(diffusers_key)
-            break
+        # if ldm_key == diffusers_key:
+        #     ldm_diffusers_keymap[ldm_key] = diffusers_key
+        #     matched_ldm_keys.append(ldm_key)
+        #     matched_diffusers_keys.append(diffusers_key)
+        #     break

         # if we already have this key mapped, skip it
         if diffusers_key in matched_diffusers_keys:
@@ -320,7 +331,7 @@ for ldm_key in ldm_dict_keys:
             did_reduce_diffusers = True

         # check to see if they match within a margin of error
-        mse = torch.nn.functional.mse_loss(ldm_weight, diffusers_weight)
+        mse = torch.nn.functional.mse_loss(ldm_weight.float(), diffusers_weight.float())
         if mse < error_margin:
             ldm_diffusers_keymap[ldm_key] = diffusers_key
             matched_ldm_keys.append(ldm_key)
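The float casts and the tightened error_margin both serve the script's core matching strategy: an LDM key and a diffusers key are paired when their tensors are numerically near-identical. A self-contained sketch of that idea on toy dictionaries; the real script also handles details such as shape reduction and component prefixes:

import torch
import torch.nn.functional as F

error_margin = 1e-8

# Toy stand-ins for the two checkpoints: the diffusers dict holds the same
# tensors under different key names.
ldm_state_dict = {"a.weight": torch.randn(4, 4), "b.weight": torch.randn(2, 2)}
diffusers_state_dict = {
    "x.weight": ldm_state_dict["a.weight"].clone(),
    "y.weight": ldm_state_dict["b.weight"].clone(),
}

ldm_diffusers_keymap = {}
matched_diffusers_keys = []

for ldm_key, ldm_weight in ldm_state_dict.items():
    for diffusers_key, diffusers_weight in diffusers_state_dict.items():
        if diffusers_key in matched_diffusers_keys:
            continue  # this diffusers key is already claimed
        if ldm_weight.shape != diffusers_weight.shape:
            continue  # shapes must agree before comparing values
        # compare in float32 so fp16 rounding does not mask a true match
        mse = F.mse_loss(ldm_weight.float(), diffusers_weight.float())
        if mse < error_margin:
            ldm_diffusers_keymap[ldm_key] = diffusers_key
            matched_diffusers_keys.append(diffusers_key)
            break

print(ldm_diffusers_keymap)  # {'a.weight': 'x.weight', 'b.weight': 'y.weight'}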