mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Fixed weight mapping for refiner
This commit is contained in:
@@ -70,8 +70,10 @@ diffusers_file_path = file_path
|
|||||||
if args.ssd:
|
if args.ssd:
|
||||||
diffusers_file_path = "segmind/SSD-1B"
|
diffusers_file_path = "segmind/SSD-1B"
|
||||||
|
|
||||||
if args.refiner:
|
# if args.refiner:
|
||||||
diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
|
# diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
|
||||||
|
|
||||||
|
diffusers_file_path = file_path if len(args.file_1) == 1 else args.file_1[1]
|
||||||
|
|
||||||
if not args.refiner:
|
if not args.refiner:
|
||||||
|
|
||||||
@@ -98,11 +100,17 @@ else:
|
|||||||
# refiner wont work directly with stable diffusion
|
# refiner wont work directly with stable diffusion
|
||||||
# so we need to load the model and then load the state dict
|
# so we need to load the model and then load the state dict
|
||||||
diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
|
diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||||
file_path,
|
diffusers_file_path,
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
use_safetensors=True,
|
use_safetensors=True,
|
||||||
variant="fp16",
|
variant="fp16",
|
||||||
).to(device)
|
).to(device)
|
||||||
|
# diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||||
|
# file_path,
|
||||||
|
# torch_dtype=torch.float16,
|
||||||
|
# use_safetensors=True,
|
||||||
|
# variant="fp16",
|
||||||
|
# ).to(device)
|
||||||
|
|
||||||
SD_PREFIX_VAE = "vae"
|
SD_PREFIX_VAE = "vae"
|
||||||
SD_PREFIX_UNET = "unet"
|
SD_PREFIX_UNET = "unet"
|
||||||
@@ -141,7 +149,7 @@ total_keys = len(ldm_dict_keys)
|
|||||||
matched_ldm_keys = []
|
matched_ldm_keys = []
|
||||||
matched_diffusers_keys = []
|
matched_diffusers_keys = []
|
||||||
|
|
||||||
error_margin = 1e-6
|
error_margin = 1e-8
|
||||||
|
|
||||||
tmp_merge_key = "TMP___MERGE"
|
tmp_merge_key = "TMP___MERGE"
|
||||||
|
|
||||||
@@ -172,6 +180,9 @@ if args.sdxl or args.sd2 or args.ssd or args.refiner:
|
|||||||
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
|
if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
|
||||||
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
|
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
|
||||||
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
|
d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
|
||||||
|
elif "conditioner.embedders.0.model.text_projection" in ldm_dict_keys:
|
||||||
|
# d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
|
||||||
|
d_model = int(ldm_state_dict["conditioner.embedders.0.model.text_projection"].shape[0])
|
||||||
else:
|
else:
|
||||||
d_model = 1024
|
d_model = 1024
|
||||||
|
|
||||||
@@ -291,11 +302,11 @@ for ldm_key in ldm_dict_keys:
|
|||||||
diffusers_reduced_shape_tuple = get_reduced_shape(diffusers_shape_tuple)
|
diffusers_reduced_shape_tuple = get_reduced_shape(diffusers_shape_tuple)
|
||||||
|
|
||||||
# That was easy. Same key
|
# That was easy. Same key
|
||||||
if ldm_key == diffusers_key:
|
# if ldm_key == diffusers_key:
|
||||||
ldm_diffusers_keymap[ldm_key] = diffusers_key
|
# ldm_diffusers_keymap[ldm_key] = diffusers_key
|
||||||
matched_ldm_keys.append(ldm_key)
|
# matched_ldm_keys.append(ldm_key)
|
||||||
matched_diffusers_keys.append(diffusers_key)
|
# matched_diffusers_keys.append(diffusers_key)
|
||||||
break
|
# break
|
||||||
|
|
||||||
# if we already have this key mapped, skip it
|
# if we already have this key mapped, skip it
|
||||||
if diffusers_key in matched_diffusers_keys:
|
if diffusers_key in matched_diffusers_keys:
|
||||||
@@ -320,7 +331,7 @@ for ldm_key in ldm_dict_keys:
|
|||||||
did_reduce_diffusers = True
|
did_reduce_diffusers = True
|
||||||
|
|
||||||
# check to see if they match within a margin of error
|
# check to see if they match within a margin of error
|
||||||
mse = torch.nn.functional.mse_loss(ldm_weight, diffusers_weight)
|
mse = torch.nn.functional.mse_loss(ldm_weight.float(), diffusers_weight.float())
|
||||||
if mse < error_margin:
|
if mse < error_margin:
|
||||||
ldm_diffusers_keymap[ldm_key] = diffusers_key
|
ldm_diffusers_keymap[ldm_key] = diffusers_key
|
||||||
matched_ldm_keys.append(ldm_key)
|
matched_ldm_keys.append(ldm_key)
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Binary file not shown.
@@ -4,7 +4,24 @@
|
|||||||
"shape": [],
|
"shape": [],
|
||||||
"min": 4.60546875,
|
"min": 4.60546875,
|
||||||
"max": 4.60546875
|
"max": 4.60546875
|
||||||
|
},
|
||||||
|
"conditioner.embedders.0.model.text_projection": {
|
||||||
|
"shape": [
|
||||||
|
1280,
|
||||||
|
1280
|
||||||
|
],
|
||||||
|
"min": -0.15966796875,
|
||||||
|
"max": 0.230712890625
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"diffusers": {}
|
"diffusers": {
|
||||||
|
"te1_text_projection.weight": {
|
||||||
|
"shape": [
|
||||||
|
1280,
|
||||||
|
1280
|
||||||
|
],
|
||||||
|
"min": -0.15966796875,
|
||||||
|
"max": 0.230712890625
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user