Fixed weight mapping for refiner

This commit is contained in:
Jaret Burkett
2023-11-06 07:37:47 -07:00
parent 93ea955d7c
commit a8b3b8b8da
4 changed files with 702 additions and 675 deletions

View File

@@ -70,8 +70,10 @@ diffusers_file_path = file_path
if args.ssd: if args.ssd:
diffusers_file_path = "segmind/SSD-1B" diffusers_file_path = "segmind/SSD-1B"
if args.refiner: # if args.refiner:
diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0" # diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
diffusers_file_path = file_path if len(args.file_1) == 1 else args.file_1[1]
if not args.refiner: if not args.refiner:
@@ -98,11 +100,17 @@ else:
# refiner won't work directly with stable diffusion # refiner won't work directly with stable diffusion
# so we need to load the model and then load the state dict # so we need to load the model and then load the state dict
diffusers_pipeline = StableDiffusionXLPipeline.from_single_file( diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
file_path, diffusers_file_path,
torch_dtype=torch.float16, torch_dtype=torch.float16,
use_safetensors=True, use_safetensors=True,
variant="fp16", variant="fp16",
).to(device) ).to(device)
# diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
# file_path,
# torch_dtype=torch.float16,
# use_safetensors=True,
# variant="fp16",
# ).to(device)
SD_PREFIX_VAE = "vae" SD_PREFIX_VAE = "vae"
SD_PREFIX_UNET = "unet" SD_PREFIX_UNET = "unet"
@@ -141,7 +149,7 @@ total_keys = len(ldm_dict_keys)
matched_ldm_keys = [] matched_ldm_keys = []
matched_diffusers_keys = [] matched_diffusers_keys = []
error_margin = 1e-6 error_margin = 1e-8
tmp_merge_key = "TMP___MERGE" tmp_merge_key = "TMP___MERGE"
@@ -172,6 +180,9 @@ if args.sdxl or args.sd2 or args.ssd or args.refiner:
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys: if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
# d_model = int(checkpoint[prefix + "text_projection"].shape[0])) # d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0]) d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
elif "conditioner.embedders.0.model.text_projection" in ldm_dict_keys:
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
d_model = int(ldm_state_dict["conditioner.embedders.0.model.text_projection"].shape[0])
else: else:
d_model = 1024 d_model = 1024
@@ -291,11 +302,11 @@ for ldm_key in ldm_dict_keys:
diffusers_reduced_shape_tuple = get_reduced_shape(diffusers_shape_tuple) diffusers_reduced_shape_tuple = get_reduced_shape(diffusers_shape_tuple)
# That was easy. Same key # That was easy. Same key
if ldm_key == diffusers_key: # if ldm_key == diffusers_key:
ldm_diffusers_keymap[ldm_key] = diffusers_key # ldm_diffusers_keymap[ldm_key] = diffusers_key
matched_ldm_keys.append(ldm_key) # matched_ldm_keys.append(ldm_key)
matched_diffusers_keys.append(diffusers_key) # matched_diffusers_keys.append(diffusers_key)
break # break
# if we already have this key mapped, skip it # if we already have this key mapped, skip it
if diffusers_key in matched_diffusers_keys: if diffusers_key in matched_diffusers_keys:
@@ -320,7 +331,7 @@ for ldm_key in ldm_dict_keys:
did_reduce_diffusers = True did_reduce_diffusers = True
# check to see if they match within a margin of error # check to see if they match within a margin of error
mse = torch.nn.functional.mse_loss(ldm_weight, diffusers_weight) mse = torch.nn.functional.mse_loss(ldm_weight.float(), diffusers_weight.float())
if mse < error_margin: if mse < error_margin:
ldm_diffusers_keymap[ldm_key] = diffusers_key ldm_diffusers_keymap[ldm_key] = diffusers_key
matched_ldm_keys.append(ldm_key) matched_ldm_keys.append(ldm_key)

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,24 @@
"shape": [], "shape": [],
"min": 4.60546875, "min": 4.60546875,
"max": 4.60546875 "max": 4.60546875
},
"conditioner.embedders.0.model.text_projection": {
"shape": [
1280,
1280
],
"min": -0.15966796875,
"max": 0.230712890625
} }
}, },
"diffusers": {} "diffusers": {
"te1_text_projection.weight": {
"shape": [
1280,
1280
],
"min": -0.15966796875,
"max": 0.230712890625
}
}
} }