mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-23 15:59:32 +00:00
Merge branch 'main' into development
This commit is contained in:
@@ -709,61 +709,28 @@ class StableDiffusion:
|
||||
for i, encoder in enumerate(self.text_encoder):
|
||||
for k, v in encoder.state_dict().items():
|
||||
new_key = k if k.startswith(
|
||||
f"{SD_PREFIX_TEXT_ENCODER}{i}") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}"
|
||||
f"{SD_PREFIX_TEXT_ENCODER}{i}_") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}"
|
||||
state_dict[new_key] = v
|
||||
else:
|
||||
for k, v in self.text_encoder.state_dict().items():
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER}") else f"{SD_PREFIX_TEXT_ENCODER}_{k}"
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER}_") else f"{SD_PREFIX_TEXT_ENCODER}_{k}"
|
||||
state_dict[new_key] = v
|
||||
if unet:
|
||||
for k, v in self.unet.state_dict().items():
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_UNET}") else f"{SD_PREFIX_UNET}_{k}"
|
||||
new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
|
||||
state_dict[new_key] = v
|
||||
return state_dict
|
||||
|
||||
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
|
||||
state_dict = {}
|
||||
# prepare metadata
|
||||
meta = get_meta_for_safetensors(meta)
|
||||
|
||||
def update_sd(prefix, sd):
|
||||
for k, v in sd.items():
|
||||
key = prefix + k
|
||||
v = v.detach().clone()
|
||||
state_dict[key] = v.to("cpu", dtype=get_torch_dtype(save_dtype))
|
||||
# make sure there are not nan values
|
||||
if torch.isnan(state_dict[key]).any():
|
||||
raise ValueError(f"NaN value in state dict: {key}")
|
||||
|
||||
# todo see what logit scale is
|
||||
version_string = '1'
|
||||
if self.is_v2:
|
||||
version_string = '2'
|
||||
if self.is_xl:
|
||||
save_ldm_model_from_diffusers(
|
||||
sd=self,
|
||||
output_file=output_file,
|
||||
meta=meta,
|
||||
save_dtype=save_dtype,
|
||||
sd_version='sdxl',
|
||||
)
|
||||
|
||||
else:
|
||||
# Convert the UNet model
|
||||
unet_state_dict = convert_unet_state_dict_to_sd(self.is_v2, self.unet.state_dict())
|
||||
update_sd("model.diffusion_model.", unet_state_dict)
|
||||
|
||||
# Convert the text encoder model
|
||||
if self.is_v2:
|
||||
make_dummy = True
|
||||
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(self.text_encoder.state_dict(), make_dummy)
|
||||
update_sd("cond_stage_model.model.", text_enc_dict)
|
||||
else:
|
||||
text_enc_dict = self.text_encoder.state_dict()
|
||||
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
||||
|
||||
# Convert the VAE
|
||||
if self.vae is not None:
|
||||
vae_dict = model_util.convert_vae_state_dict(self.vae.state_dict())
|
||||
update_sd("first_stage_model.", vae_dict)
|
||||
|
||||
# make sure parent folder exists
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
save_file(state_dict, output_file, metadata=meta)
|
||||
version_string = 'sdxl'
|
||||
save_ldm_model_from_diffusers(
|
||||
sd=self,
|
||||
output_file=output_file,
|
||||
meta=meta,
|
||||
save_dtype=save_dtype,
|
||||
sd_version=version_string,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user