Merge branch 'main' into development

This commit is contained in:
Jaret Burkett
2023-08-28 16:21:51 -06:00
9 changed files with 5120 additions and 150 deletions

View File

@@ -709,61 +709,28 @@ class StableDiffusion:
for i, encoder in enumerate(self.text_encoder):
for k, v in encoder.state_dict().items():
new_key = k if k.startswith(
f"{SD_PREFIX_TEXT_ENCODER}{i}") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}"
f"{SD_PREFIX_TEXT_ENCODER}{i}_") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}"
state_dict[new_key] = v
else:
for k, v in self.text_encoder.state_dict().items():
new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER}") else f"{SD_PREFIX_TEXT_ENCODER}_{k}"
new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER}_") else f"{SD_PREFIX_TEXT_ENCODER}_{k}"
state_dict[new_key] = v
if unet:
for k, v in self.unet.state_dict().items():
new_key = k if k.startswith(f"{SD_PREFIX_UNET}") else f"{SD_PREFIX_UNET}_{k}"
new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
state_dict[new_key] = v
return state_dict
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
state_dict = {}
# prepare metadata
meta = get_meta_for_safetensors(meta)
def update_sd(prefix, sd):
for k, v in sd.items():
key = prefix + k
v = v.detach().clone()
state_dict[key] = v.to("cpu", dtype=get_torch_dtype(save_dtype))
# make sure there are not nan values
if torch.isnan(state_dict[key]).any():
raise ValueError(f"NaN value in state dict: {key}")
# todo see what logit scale is
version_string = '1'
if self.is_v2:
version_string = '2'
if self.is_xl:
save_ldm_model_from_diffusers(
sd=self,
output_file=output_file,
meta=meta,
save_dtype=save_dtype,
sd_version='sdxl',
)
else:
# Convert the UNet model
unet_state_dict = convert_unet_state_dict_to_sd(self.is_v2, self.unet.state_dict())
update_sd("model.diffusion_model.", unet_state_dict)
# Convert the text encoder model
if self.is_v2:
make_dummy = True
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(self.text_encoder.state_dict(), make_dummy)
update_sd("cond_stage_model.model.", text_enc_dict)
else:
text_enc_dict = self.text_encoder.state_dict()
update_sd("cond_stage_model.transformer.", text_enc_dict)
# Convert the VAE
if self.vae is not None:
vae_dict = model_util.convert_vae_state_dict(self.vae.state_dict())
update_sd("first_stage_model.", vae_dict)
# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
save_file(state_dict, output_file, metadata=meta)
version_string = 'sdxl'
save_ldm_model_from_diffusers(
sd=self,
output_file=output_file,
meta=meta,
save_dtype=save_dtype,
sd_version=version_string,
)