Added extensions and an example extension that merges models

This commit is contained in:
Jaret Burkett
2023-08-04 09:37:24 -06:00
parent b865ac8b24
commit 7e4e660663
14 changed files with 366 additions and 24 deletions

View File

@@ -8,7 +8,10 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import
from safetensors.torch import save_file
from tqdm import tqdm
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
convert_vae_state_dict
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.paths import REPOS_ROOT
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
@@ -161,6 +164,7 @@ class StableDiffusion:
scheduler_type='dpm',
device=self.device_torch,
load_safety_checker=False,
requires_safety_checker=False,
).to(self.device_torch)
pipe.register_to_config(requires_safety_checker=False)
text_encoder = pipe.text_encoder
@@ -468,17 +472,16 @@ class StableDiffusion:
)
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
state_dict = {}
def update_sd(prefix, sd):
for k, v in sd.items():
key = prefix + k
v = v.detach().clone().to("cpu").to(get_torch_dtype(save_dtype))
state_dict[key] = v
# todo see what logit scale is
if self.is_xl:
state_dict = {}
def update_sd(prefix, sd):
for k, v in sd.items():
key = prefix + k
v = v.detach().clone().to("cpu").to(get_torch_dtype(save_dtype))
state_dict[key] = v
# Convert the UNet model
update_sd("model.diffusion_model.", self.unet.state_dict())
@@ -488,19 +491,25 @@ class StableDiffusion:
text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(self.text_encoder[1].state_dict(), logit_scale)
update_sd("conditioner.embedders.1.model.", text_enc2_dict)
else:
# Convert the UNet model
unet_state_dict = convert_unet_state_dict_to_sd(self.is_v2, self.unet.state_dict())
update_sd("model.diffusion_model.", unet_state_dict)
# Convert the text encoder model
if self.is_v2:
make_dummy = True
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(self.text_encoder.state_dict(), make_dummy)
update_sd("cond_stage_model.model.", text_enc_dict)
else:
text_enc_dict = self.text_encoder.state_dict()
update_sd("cond_stage_model.transformer.", text_enc_dict)
# Convert the VAE
if self.vae is not None:
vae_dict = model_util.convert_vae_state_dict(self.vae.state_dict())
update_sd("first_stage_model.", vae_dict)
# Put together new checkpoint
key_count = len(state_dict.keys())
new_ckpt = {"state_dict": state_dict}
if model_util.is_safetensors(output_file):
save_file(state_dict, output_file)
else:
torch.save(new_ckpt, output_file, meta)
return key_count
else:
raise NotImplementedError("sdv1.x, sdv2.x is not implemented yet")
# prepare metadata
meta = get_meta_for_safetensors(meta)
save_file(state_dict, output_file, metadata=meta)