mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-24 06:13:56 +00:00
Various bug fixes and improvements
This commit is contained in:
@@ -7,9 +7,11 @@ import os
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
|
||||
from safetensors.torch import save_file
|
||||
from tqdm import tqdm
|
||||
from torchvision.transforms import Resize
|
||||
|
||||
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
|
||||
convert_vae_state_dict
|
||||
from toolkit import train_tools
|
||||
from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
@@ -180,6 +182,7 @@ class StableDiffusion:
|
||||
device=self.device_torch,
|
||||
load_safety_checker=False,
|
||||
requires_safety_checker=False,
|
||||
safety_checker=False
|
||||
).to(self.device_torch)
|
||||
else:
|
||||
pipe = pipln.from_single_file(
|
||||
@@ -189,7 +192,9 @@ class StableDiffusion:
|
||||
device=self.device_torch,
|
||||
load_safety_checker=False,
|
||||
requires_safety_checker=False,
|
||||
safety_checker=False
|
||||
).to(self.device_torch)
|
||||
|
||||
pipe.register_to_config(requires_safety_checker=False)
|
||||
text_encoder = pipe.text_encoder
|
||||
text_encoder.to(self.device_torch, dtype=dtype)
|
||||
@@ -379,28 +384,60 @@ class StableDiffusion:
|
||||
dynamic_crops=False, # look into this
|
||||
dtype=dtype,
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
return train_util.concat_embeddings(
|
||||
prompt_ids, prompt_ids, bs
|
||||
)
|
||||
return prompt_ids
|
||||
else:
|
||||
return None
|
||||
|
||||
def predict_noise(
|
||||
self,
|
||||
latents: torch.FloatTensor,
|
||||
text_embeddings: PromptEmbeds,
|
||||
timestep: int,
|
||||
latents: torch.Tensor,
|
||||
text_embeddings: Union[PromptEmbeds, None] = None,
|
||||
timestep: Union[int, torch.Tensor] = 1,
|
||||
guidance_scale=7.5,
|
||||
guidance_rescale=0, # 0.7
|
||||
guidance_rescale=0, # 0.7 sdxl
|
||||
add_time_ids=None,
|
||||
conditional_embeddings: Union[PromptEmbeds, None] = None,
|
||||
unconditional_embeddings: Union[PromptEmbeds, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# get the embeddings
|
||||
if text_embeddings is None and conditional_embeddings is None:
|
||||
raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
|
||||
if text_embeddings is None and unconditional_embeddings is not None:
|
||||
text_embeddings = train_tools.concat_prompt_embeddings(
|
||||
unconditional_embeddings, # negative embedding
|
||||
conditional_embeddings, # positive embedding
|
||||
latents.shape[0], # batch size
|
||||
)
|
||||
elif text_embeddings is None and conditional_embeddings is not None:
|
||||
# not doing cfg
|
||||
text_embeddings = conditional_embeddings
|
||||
|
||||
# CFG is comparing neg and positive, if we have concatenated embeddings
|
||||
# then we are doing it, otherwise we are not and takes half the time.
|
||||
do_classifier_free_guidance = True
|
||||
|
||||
# check if batch size of embeddings matches batch size of latents
|
||||
if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
|
||||
do_classifier_free_guidance = False
|
||||
elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
|
||||
raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")
|
||||
|
||||
if self.is_xl:
|
||||
if add_time_ids is None:
|
||||
add_time_ids = self.get_time_ids_from_latents(latents)
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
if do_classifier_free_guidance:
|
||||
# todo check this with larger batches
|
||||
train_util.concat_embeddings(
|
||||
add_time_ids, add_time_ids, 1
|
||||
)
|
||||
else:
|
||||
# concat to fit batch size
|
||||
add_time_ids = torch.cat([add_time_ids] * latents.shape[0])
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
|
||||
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
|
||||
|
||||
@@ -417,20 +454,24 @@ class StableDiffusion:
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
# perform guidance
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
|
||||
if guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
|
||||
if guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
else:
|
||||
# if we are doing classifier free guidance, need to double up
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
if do_classifier_free_guidance:
|
||||
# if we are doing classifier free guidance, need to double up
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
else:
|
||||
latent_model_input = latents
|
||||
|
||||
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
|
||||
|
||||
@@ -441,10 +482,12 @@ class StableDiffusion:
|
||||
encoder_hidden_states=text_embeddings.text_embeds,
|
||||
).sample
|
||||
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
# perform guidance
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
return noise_pred
|
||||
|
||||
@@ -495,14 +538,68 @@ class StableDiffusion:
|
||||
)
|
||||
)
|
||||
|
||||
def encode_images(
        self,
        image_list: List[torch.Tensor],
        device=None,
        dtype=None
):
    """Encode a list of image tensors into scaled VAE latents.

    Args:
        image_list: Image tensors to encode. Each tensor is indexed as
            ``shape[1]`` x ``shape[2]`` for its spatial size
            (assumes a (channels, height, width) layout - TODO confirm).
            All images must end up the same size, because they are stacked
            into one batch.
        device: Device for the returned latents. Defaults to ``self.device``.
        dtype: Dtype for the returned latents. Defaults to
            ``self.torch_dtype``.

    Returns:
        A tensor of latents sampled from the VAE posterior, multiplied by
        0.18215 and moved to ``device`` / ``dtype``.
    """
    if device is None:
        device = self.device
    if dtype is None:
        dtype = self.torch_dtype

    # Move the vae to the working device if it is on cpu.
    # A torch.device never compares equal to a plain string, so the device
    # type has to be compared instead of the device object itself.
    if torch.device(self.vae.device).type == 'cpu':
        self.vae.to(self.device)

    prepared = []
    for image in image_list:
        # the vae input lives on the model device / dtype, not the output one
        image = image.to(self.device, dtype=self.torch_dtype)
        height, width = image.shape[1], image.shape[2]
        # resize images if not divisible by 8 (rounds each side down)
        if height % 8 != 0 or width % 8 != 0:
            image = Resize((height // 8 * 8, width // 8 * 8))(image)
        prepared.append(image)

    images = torch.stack(prepared)
    latents = self.vae.encode(images).latent_dist.sample()
    # NOTE(review): 0.18215 is presumably the SD 1.x/2.x VAE scaling factor;
    # confirm it is also the right value when self.is_xl is set.
    latents = latents * 0.18215

    return latents.to(device, dtype=dtype)
|
||||
|
||||
def encode_image_prompt_pairs(
        self,
        prompt_list: List[str],
        image_list: List[torch.Tensor],
        device=None,
        dtype=None
):
    """Encode prompts into embeddings for prompt/image pairs.

    Args:
        prompt_list: Prompts to encode, one ``self.encode_prompt`` call each.
        image_list: Images paired with the prompts. Currently unused: this
            method does not encode them yet.
        device: Device for the returned embeddings. Defaults to
            ``self.device_torch``.
        dtype: Dtype for the returned embeddings. Defaults to
            ``self.torch_dtype``.

    Returns:
        A ``(embedding_list, latent_list)`` tuple. ``latent_list`` is always
        an empty list for now.
    """
    # todo check image types and expand and rescale as needed
    # device and dtype are for outputs
    if device is None:
        device = self.device_torch
    if dtype is None:
        dtype = self.torch_dtype

    # embed the prompts
    embedding_list = [
        self.encode_prompt(prompt).to(device, dtype=dtype)
        for prompt in prompt_list
    ]

    # images are not encoded yet, so no latents are produced
    latent_list = []

    return embedding_list, latent_list
|
||||
|
||||
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
|
||||
state_dict = {}
|
||||
|
||||
def update_sd(prefix, sd):
|
||||
for k, v in sd.items():
|
||||
key = prefix + k
|
||||
v = v.detach().clone().to("cpu").to(get_torch_dtype(save_dtype))
|
||||
state_dict[key] = v
|
||||
v = v.detach().clone()
|
||||
state_dict[key] = v.to("cpu", dtype=get_torch_dtype(save_dtype))
|
||||
|
||||
# todo see what logit scale is
|
||||
if self.is_xl:
|
||||
@@ -536,4 +633,6 @@ class StableDiffusion:
|
||||
|
||||
# prepare metadata
|
||||
meta = get_meta_for_safetensors(meta)
|
||||
# make sure parent folder exists
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
save_file(state_dict, output_file, metadata=meta)
|
||||
|
||||
Reference in New Issue
Block a user