From 379992d89ee085bcdc0d9c062ea0fcd01df76e3e Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sat, 12 Aug 2023 05:59:50 -0600
Subject: [PATCH] Various bug fixes and improvements

---
 .../ImageReferenceSliderTrainerProcess.py |  75 +++------
 .../config/train.example.yaml             |  24 +--
 toolkit/lora_special.py                   |  20 ++-
 toolkit/stable_diffusion_model.py         | 149 +++++++++++++++---
 toolkit/train_tools.py                    |   5 +-
 5 files changed, 180 insertions(+), 93 deletions(-)

diff --git a/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py
index a4816bec..127f56cc 100644
--- a/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py
+++ b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py
@@ -2,11 +2,12 @@ import copy
 import random
 from collections import OrderedDict
 import os
+from contextlib import nullcontext
 from typing import Optional, Union, List
 from torch.utils.data import ConcatDataset, DataLoader
 from toolkit.data_loader import PairedImageDataset
 from toolkit.prompt_utils import concat_prompt_embeds
-from toolkit.stable_diffusion_model import StableDiffusion
+from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
 from toolkit.train_tools import get_torch_dtype
 import gc
 from toolkit import train_tools
@@ -80,34 +81,16 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         imgs, prompts = batch
         dtype = get_torch_dtype(self.train_config.dtype)
         imgs: torch.Tensor = imgs.to(self.device_torch, dtype=dtype)
-        # split batched images in half so left is negative and right is positive
         negative_images, positive_images = torch.chunk(imgs, 2, dim=3)
+        positive_latents = self.sd.encode_images(positive_images)
+        negative_latents = self.sd.encode_images(negative_images)
+
         height = positive_images.shape[2]
         width = positive_images.shape[3]
         batch_size = positive_images.shape[0]

-        # encode the images
-        positive_latents = self.sd.vae.encode(positive_images).latent_dist.sample()
-        positive_latents = positive_latents * 0.18215
-        negative_latents = self.sd.vae.encode(negative_images).latent_dist.sample()
-        negative_latents = negative_latents * 0.18215
-
-        embedding_list = []
-        negative_embedding_list = []
-        # embed the prompts
-        for prompt in prompts:
-            embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
-            embedding_list.append(embedding)
-        # just empty for now
-        # todo cache this?
-            negative_embed = self.sd.encode_prompt('').to(self.device_torch, dtype=dtype)
-            negative_embedding_list.append(negative_embed)
-
-        conditional_embeds = concat_prompt_embeds(embedding_list)
-        unconditional_embeds = concat_prompt_embeds(negative_embedding_list)
-
         if self.train_config.gradient_checkpointing:
             # may get disabled elsewhere
             self.sd.unet.enable_gradient_checkpointing()
@@ -115,26 +98,12 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         noise_scheduler = self.sd.noise_scheduler
         optimizer = self.optimizer
         lr_scheduler = self.lr_scheduler
-        loss_function = torch.nn.MSELoss()
-
-        def get_noise_pred(neg, pos, gs, cts, dn):
-            return self.sd.predict_noise(
-                latents=dn,
-                text_embeddings=train_tools.concat_prompt_embeddings(
-                    neg,  # negative prompt
-                    pos,  # positive prompt
-                    self.train_config.batch_size,
-                ),
-                timestep=cts,
-                guidance_scale=gs,
-            )
-
-        with torch.no_grad():
             self.sd.noise_scheduler.set_timesteps(
                 self.train_config.max_denoising_steps, device=self.device_torch
             )
-            timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
+            timesteps = torch.randint(0, self.train_config.max_denoising_steps, (1,), device=self.device_torch)
             timesteps = timesteps.long()

             # get noise
@@ -147,6 +116,7 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):

             if do_mirror_loss:
                 # mirror the noise
+                # torch shape is [batch, channels, height, width]
                 noise_negative = torch.flip(noise_positive.clone(), dims=[3])
             else:
                 noise_negative = noise_positive.clone()
@@ -159,8 +129,6 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
             noisy_latents = torch.cat([noisy_positive_latents, noisy_negative_latents], dim=0)
             noise = torch.cat([noise_positive, noise_negative], dim=0)
             timesteps = torch.cat([timesteps, timesteps], dim=0)
-            conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
-            unconditional_embeds = concat_prompt_embeds([unconditional_embeds, unconditional_embeds])
             network_multiplier = [1.0, -1.0]

         flush()
@@ -170,22 +138,31 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         loss_mirror_float = None

         self.optimizer.zero_grad()
+        noisy_latents.requires_grad = False
+
+        # if training the text encoder, enable grads; otherwise run under no_grad
+        with torch.set_grad_enabled(self.train_config.train_text_encoder):
+            # text encoding
+            embedding_list = []
+            # embed the prompts
+            for prompt in prompts:
+                embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
+                embedding_list.append(embedding)
+            conditional_embeds = concat_prompt_embeds(embedding_list)
+            conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
+
         with self.network:
             assert self.network.is_active
-            loss_list = []
-
             # do positive first
             self.network.multiplier = network_multiplier

-            noise_pred = get_noise_pred(
-                unconditional_embeds,
-                conditional_embeds,
-                1,
-                timesteps,
-                noisy_latents
+            noise_pred = self.sd.predict_noise(
+                latents=noisy_latents,
+                conditional_embeddings=conditional_embeds,
+                timestep=timesteps,
             )

-            if self.sd.is_v2:  # check is vpred, don't want to track it down right now
+            if self.sd.prediction_type == 'v_prediction':
                 # v-parameterization training
                 target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
             else:
@@ -199,7 +176,6 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):

             loss = loss.mean()
             loss_slide_float = loss.item()
-
             if do_mirror_loss:
                 noise_pred_pos, noise_pred_neg = torch.chunk(noise_pred, 2, dim=0)
                 # mirror the negative
@@ -221,7 +197,6 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):

             optimizer.step()
             lr_scheduler.step()
-
             # reset network
             self.network.multiplier = 1.0

diff --git a/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml b/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml
index 301790f3..52e3d5d6 100644
--- a/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml
+++ b/extensions_built_in/image_reference_slider_trainer/config/train.example.yaml
@@ -9,17 +9,21 @@ config:
   # for tensorboard logging
   log_dir: "/home/jaret/Dev/.tensorboard"
   network:
-    type: "lierla" # lierla is traditional LoRA that works everywhere, only linear layers
-    rank: 16
-    alpha: 8
+    type: "lora"
+    linear: 64
+    linear_alpha: 32
+    conv: 32
+    conv_alpha: 16
   train:
     noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
-    steps: 1000
-    lr: 5e-5
+    steps: 5000
+    lr: 1e-4
     train_unet: true
     gradient_checkpointing: true
-    train_text_encoder: false
-    optimizer: "lion8bit"
+    train_text_encoder: true
+    optimizer: "adamw"
+    optimizer_params:
+      weight_decay: 1e-2
     lr_scheduler: "constant"
     max_denoising_steps: 1000
     batch_size: 1
@@ -36,11 +40,11 @@ config:
     is_v_pred: false # for v-prediction models (most v2 models)
   save:
     dtype: float16 # precision to save
-    save_every: 100 # save every this many steps
+    save_every: 1000 # save every this many steps
     max_step_saves_to_keep: 2 # only affects step counts
   sample:
     sampler: "ddpm" # must match train.noise_scheduler
-    sample_every: 20 # sample every this many steps
+    sample_every: 100 # sample every this many steps
     width: 512
     height: 512
     prompts:
@@ -81,6 +85,8 @@ config:
       - 512
     slider_pair_folder: "/mnt/Datasets/stable-diffusion/slider_reference/subject_turner"
     target_class: "photo of a person"
+# additional_losses:
+#   - "mirror"

 meta:
diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index 35c4223a..79ece460 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -97,21 +97,25 @@ class LoRAModule(torch.nn.Module):
             if len(self.multiplier) == 0:
                 # single item, just return it
                 return self.multiplier[0]
+            elif len(self.multiplier) == batch_size:
+                # not doing CFG
+                multiplier_tensor = torch.tensor(self.multiplier).to(lora_up.device, dtype=lora_up.dtype)
             else:
+                # we have a list of multipliers, so we need to get the multiplier for this batch
                 multiplier_tensor = torch.tensor(self.multiplier * 2).to(lora_up.device, dtype=lora_up.dtype)
                 # should be 1 for if total batch size was 1
                 num_interleaves = (batch_size // 2) // len(self.multiplier)
                 multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves)
-                # match lora_up rank
-                if len(lora_up.size()) == 2:
-                    multiplier_tensor = multiplier_tensor.view(-1, 1)
-                elif len(lora_up.size()) == 3:
-                    multiplier_tensor = multiplier_tensor.view(-1, 1, 1)
-                elif len(lora_up.size()) == 4:
-                    multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1)
-                return multiplier_tensor
+            # match lora_up rank
+            if len(lora_up.size()) == 2:
+                multiplier_tensor = multiplier_tensor.view(-1, 1)
+            elif len(lora_up.size()) == 3:
+                multiplier_tensor = multiplier_tensor.view(-1, 1, 1)
+            elif len(lora_up.size()) == 4:
+                multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1)
+            return multiplier_tensor
         else:
             return self.multiplier

diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index f06f2929..65085f67 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -7,9 +7,11 @@ import os
 from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
 from safetensors.torch import save_file
 from tqdm import tqdm
+from torchvision.transforms import Resize

 from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
     convert_vae_state_dict
+from toolkit import train_tools
 from toolkit.config_modules import ModelConfig, GenerateImageConfig
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.paths import REPOS_ROOT
@@ -180,6 +182,7 @@ class StableDiffusion:
                 device=self.device_torch,
                 load_safety_checker=False,
                 requires_safety_checker=False,
+                safety_checker=False
             ).to(self.device_torch)
         else:
             pipe = pipln.from_single_file(
@@ -189,7 +192,9 @@ class StableDiffusion:
                 device=self.device_torch,
                 load_safety_checker=False,
                 requires_safety_checker=False,
+                safety_checker=False
             ).to(self.device_torch)
+        pipe.register_to_config(requires_safety_checker=False)

         text_encoder = pipe.text_encoder
         text_encoder.to(self.device_torch, dtype=dtype)
@@ -379,28 +384,60 @@ class StableDiffusion:
                 dynamic_crops=False,  # look into this
                 dtype=dtype,
             ).to(self.device_torch, dtype=dtype)
-            return train_util.concat_embeddings(
-                prompt_ids, prompt_ids, bs
-            )
+            return prompt_ids
         else:
             return None

     def predict_noise(
             self,
-            latents: torch.FloatTensor,
-            text_embeddings: PromptEmbeds,
-            timestep: int,
+            latents: torch.Tensor,
+            text_embeddings: Union[PromptEmbeds, None] = None,
+            timestep: Union[int, torch.Tensor] = 1,
             guidance_scale=7.5,
-            guidance_rescale=0,  # 0.7
+            guidance_rescale=0,  # 0.7 sdxl
             add_time_ids=None,
+            conditional_embeddings: Union[PromptEmbeds, None] = None,
+            unconditional_embeddings: Union[PromptEmbeds, None] = None,
             **kwargs,
     ):
+        # get the embeddings
+        if text_embeddings is None and conditional_embeddings is None:
+            raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
+        if text_embeddings is None and unconditional_embeddings is not None:
+            text_embeddings = train_tools.concat_prompt_embeddings(
+                unconditional_embeddings,  # negative embedding
+                conditional_embeddings,  # positive embedding
+                latents.shape[0],  # batch size
+            )
+        elif text_embeddings is None and conditional_embeddings is not None:
+            # not doing cfg
+            text_embeddings = conditional_embeddings
+
+        # CFG compares a negative and a positive prediction. If the embeddings were
+        # concatenated we are doing CFG; otherwise we skip it, which takes half the time.
+        do_classifier_free_guidance = True
+
+        # check if batch size of embeddings matches batch size of latents
+        if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
+            do_classifier_free_guidance = False
+        elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
+            raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")

         if self.is_xl:
             if add_time_ids is None:
                 add_time_ids = self.get_time_ids_from_latents(latents)

-            latent_model_input = torch.cat([latents] * 2)
+            if do_classifier_free_guidance:
+                # todo check this with larger batches
+                add_time_ids = train_util.concat_embeddings(
+                    add_time_ids, add_time_ids, 1
+                )
+            else:
+                # concat to fit batch size
+                add_time_ids = torch.cat([add_time_ids] * latents.shape[0])
+
+            if do_classifier_free_guidance:
+                latent_model_input = torch.cat([latents] * 2)

             latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
@@ -417,20 +454,24 @@ class StableDiffusion:
                 added_cond_kwargs=added_cond_kwargs,
             ).sample

-            # perform guidance
-            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + guidance_scale * (
-                noise_pred_text - noise_pred_uncond
-            )
+            if do_classifier_free_guidance:
+                # perform guidance
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (
+                    noise_pred_text - noise_pred_uncond
+                )

-            # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
-            if guidance_rescale > 0.0:
-                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
-                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+                # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
+                if guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
         else:
-            # if we are doing classifier free guidance, need to double up
-            latent_model_input = torch.cat([latents] * 2)
+            if do_classifier_free_guidance:
+                # if we are doing classifier free guidance, need to double up
+                latent_model_input = torch.cat([latents] * 2)
+            else:
+                latent_model_input = latents

             latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
@@ -441,10 +482,12 @@ class StableDiffusion:
                 encoder_hidden_states=text_embeddings.text_embeds,
             ).sample

-            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + guidance_scale * (
-                noise_pred_text - noise_pred_uncond
-            )
+            if do_classifier_free_guidance:
+                # perform guidance
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (
+                    noise_pred_text - noise_pred_uncond
+                )

         return noise_pred
@@ -495,14 +538,68 @@ class StableDiffusion:
                 )
             )

+    def encode_images(
+            self,
+            image_list: List[torch.Tensor],
+            device=None,
+            dtype=None
+    ):
+        if device is None:
+            device = self.device
+        if dtype is None:
+            dtype = self.torch_dtype
+
+        latent_list = []
+        # move the vae to the device if it is on cpu
+        if self.vae.device.type == 'cpu':
+            self.vae.to(self.device)
+        # move to device and dtype
+        image_list = [image.to(self.device, dtype=self.torch_dtype) for image in image_list]
+
+        # resize images if not divisible by 8
+        for i in range(len(image_list)):
+            image = image_list[i]
+            if image.shape[1] % 8 != 0 or image.shape[2] % 8 != 0:
+                image_list[i] = Resize((image.shape[1] // 8 * 8, image.shape[2] // 8 * 8))(image)
+
+        images = torch.stack(image_list)
+        latents = self.vae.encode(images).latent_dist.sample()
+        latents = latents * 0.18215
+        latents = latents.to(device, dtype=dtype)
+
+        return latents
+
+    def encode_image_prompt_pairs(
+            self,
+            prompt_list: List[str],
+            image_list: List[torch.Tensor],
+            device=None,
+            dtype=None
+    ):
+        # todo check image types and expand and rescale as needed
+        # device and dtype are for outputs
+        if device is None:
+            device = self.device
+        if dtype is None:
+            dtype = self.torch_dtype
+
+        embedding_list = []
+        latent_list = []
+        # embed the prompts
+        for prompt in prompt_list:
+            embedding = self.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
+            embedding_list.append(embedding)
+
+        return embedding_list, latent_list
+
     def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
         state_dict = {}

         def update_sd(prefix, sd):
             for k, v in sd.items():
                 key = prefix + k
-                v = v.detach().clone().to("cpu").to(get_torch_dtype(save_dtype))
-                state_dict[key] = v
+                v = v.detach().clone()
+                state_dict[key] = v.to("cpu", dtype=get_torch_dtype(save_dtype))

         # todo see what logit scale is
         if self.is_xl:
@@ -536,4 +633,6 @@ class StableDiffusion:

         # prepare metadata
         meta = get_meta_for_safetensors(meta)
+        # make sure parent folder exists
+        os.makedirs(os.path.dirname(output_file), exist_ok=True)
         save_file(state_dict, output_file, metadata=meta)
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index 0e18cbc3..4ad723d8 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -34,13 +34,16 @@ SCHEDLER_SCHEDULE = "scaled_linear"


 def get_torch_dtype(dtype_str):
+    # if it is a torch dtype, return it
+    if isinstance(dtype_str, torch.dtype):
+        return dtype_str
     if dtype_str == "float" or dtype_str == "fp32" or dtype_str == "single" or dtype_str == "float32":
         return torch.float
     if dtype_str == "fp16" or dtype_str == "half" or dtype_str == "float16":
         return torch.float16
     if dtype_str == "bf16" or dtype_str == "bfloat16":
         return torch.bfloat16
-    return None
+    return dtype_str


 def replace_filewords_prompt(prompt, args: argparse.Namespace):
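
Usage note (appended for review, not part of the patch): a minimal sketch of how the refactored helpers fit together after this change. It assumes an already-constructed StableDiffusion instance named `sd` and a one-image [batch, channels, height, width] tensor `images`; both names are hypothetical stand-ins for trainer state. The calls shown (`encode_images`, `encode_prompt`, `predict_noise`, `concat_prompt_embeds`, `get_torch_dtype`) are the ones introduced or touched by this patch.

    import torch
    from toolkit.prompt_utils import concat_prompt_embeds
    from toolkit.train_tools import get_torch_dtype

    # torch dtypes now pass straight through get_torch_dtype
    dtype = get_torch_dtype(torch.float16)

    # encode_images handles device placement, resizing to a multiple of 8,
    # and the 0.18215 latent scaling the trainer previously did inline
    latents = sd.encode_images(images, dtype=dtype)

    # with only conditional embeddings given (no unconditional_embeddings),
    # predict_noise skips the CFG double-up entirely
    embeds = concat_prompt_embeds([sd.encode_prompt("photo of a person")])
    timesteps = torch.randint(0, 1000, (1,), device=latents.device).long()
    noise_pred = sd.predict_noise(
        latents=latents,
        conditional_embeddings=embeds,
        timestep=timesteps,
    )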