From 7157c316af0bfc1da39fea712d70ba55593e573f Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 23 Aug 2023 15:37:00 -0600 Subject: [PATCH] Added support for training lora, dreambooth, and fine tuning. Still need testing and docs --- extensions_built_in/sd_trainer/SDTrainer.py | 174 ++++++++++++++++++ .../__init__.py | 19 +- .../config/train.example.yaml | 0 .../TextualInversionTrainer.py | 152 --------------- jobs/process/BaseSDTrainProcess.py | 39 +++- toolkit/config_modules.py | 1 + toolkit/data_loader.py | 9 +- toolkit/stable_diffusion_model.py | 36 +++- 8 files changed, 265 insertions(+), 165 deletions(-) create mode 100644 extensions_built_in/sd_trainer/SDTrainer.py rename extensions_built_in/{textual_inversion_trainer => sd_trainer}/__init__.py (64%) rename extensions_built_in/{textual_inversion_trainer => sd_trainer}/config/train.example.yaml (100%) delete mode 100644 extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py new file mode 100644 index 00000000..4f172b15 --- /dev/null +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -0,0 +1,174 @@ +from collections import OrderedDict +from torch.utils.data import DataLoader +from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds +from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork +from toolkit.train_tools import get_torch_dtype, apply_snr_weight +import gc +import torch +from jobs.process import BaseSDTrainProcess + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class SDTrainer(BaseSDTrainProcess): + sd: StableDiffusion + data_loader: DataLoader = None + + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super().__init__(process_id, job, config, **kwargs) + pass + + def before_model_load(self): + pass + + def hook_before_train_loop(self): + self.sd.vae.eval() + self.sd.vae.to(self.device_torch) + + # textual inversion + if self.embedding is not None: + # keep original embeddings as reference + self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone() + # set text encoder to train. Not sure if this is necessary but diffusers example did it + self.sd.text_encoder.train() + + def hook_train_loop(self, batch): + with torch.no_grad(): + imgs, prompts, dataset_config = batch + + # convert the 0 or 1 for is reg to a bool list + is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])]) + if isinstance(is_reg_list, torch.Tensor): + is_reg_list = is_reg_list.numpy().tolist() + is_reg_list = [bool(x) for x in is_reg_list] + + conditioned_prompts = [] + + for prompt, is_reg in zip(prompts, is_reg_list): + + # make sure the embedding is in the prompts + if self.embedding is not None: + prompt = self.embedding.inject_embedding_to_prompt( + prompt, + expand_token=True, + add_if_not_present=True, + ) + + # make sure trigger is in the prompts if not a regularization run + if self.trigger_word is not None and not is_reg: + prompt = self.sd.inject_trigger_into_prompt( + prompt, + add_if_not_present=True, + ) + conditioned_prompts.append(prompt) + + batch_size = imgs.shape[0] + + dtype = get_torch_dtype(self.train_config.dtype) + imgs = imgs.to(self.device_torch, dtype=dtype) + latents = self.sd.encode_images(imgs) + + noise_scheduler = self.sd.noise_scheduler + optimizer = self.optimizer + lr_scheduler = self.lr_scheduler + + self.sd.noise_scheduler.set_timesteps( + self.train_config.max_denoising_steps, device=self.device_torch + ) + + timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch) + timesteps = timesteps.long() + + # get noise + noise = self.sd.get_latent_noise( + pixel_height=imgs.shape[2], + pixel_width=imgs.shape[3], + batch_size=batch_size, + noise_offset=self.train_config.noise_offset + ).to(self.device_torch, dtype=dtype) + + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # remove grads for these + noisy_latents.requires_grad = False + noise.requires_grad = False + + flush() + + self.optimizer.zero_grad() + + # text encoding + grad_on_text_encoder = False + if self.train_config.train_text_encoder: + grad_on_text_encoder = True + + if self.embedding: + grad_on_text_encoder = True + + # have a blank network so we can wrap it in a context and set multipliers without checking every time + if self.network is not None: + network = self.network + else: + network = BlankNetwork() + + # activate network if it exits + with network: + with torch.set_grad_enabled(grad_on_text_encoder): + embedding_list = [] + # embed the prompts + for prompt in conditioned_prompts: + embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype) + embedding_list.append(embedding) + conditional_embeds = concat_prompt_embeds(embedding_list) + + noise_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + timestep=timesteps, + guidance_scale=1.0, + ) + + noise = noise.to(self.device_torch, dtype=dtype) + + if self.sd.prediction_type == 'v_prediction': + # v-parameterization training + target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma) + + loss = loss.mean() + + # back propagate loss to free ram + loss.backward() + flush() + + # apply gradients + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + if self.embedding is not None: + # Let's make sure we don't update any embedding weights besides the newly added token + index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool) + index_no_updates[ + min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False + with torch.no_grad(): + self.sd.text_encoder.get_input_embeddings().weight[ + index_no_updates + ] = self.orig_embeds_params[index_no_updates] + + loss_dict = OrderedDict( + {'loss': loss.item()} + ) + + return loss_dict diff --git a/extensions_built_in/textual_inversion_trainer/__init__.py b/extensions_built_in/sd_trainer/__init__.py similarity index 64% rename from extensions_built_in/textual_inversion_trainer/__init__.py rename to extensions_built_in/sd_trainer/__init__.py index 7167178f..45aa841e 100644 --- a/extensions_built_in/textual_inversion_trainer/__init__.py +++ b/extensions_built_in/sd_trainer/__init__.py @@ -2,24 +2,29 @@ from toolkit.extension import Extension -# We make a subclass of Extension -class OffsetSliderTrainer(Extension): +# This is for generic training (LoRA, Dreambooth, FineTuning) +class SDTrainerExtension(Extension): # uid must be unique, it is how the extension is identified - uid = "textual_inversion_trainer" + uid = "sd_trainer" # name is the name of the extension for printing - name = "Textual Inversion Trainer" + name = "SD Trainer" # This is where your process class is loaded # keep your imports in here so they don't slow down the rest of the program @classmethod def get_process(cls): # import your process class here so it is only loaded when needed and return it - from .TextualInversionTrainer import TextualInversionTrainer - return TextualInversionTrainer + from .SDTrainer import SDTrainer + return SDTrainer + + +# for backwards compatability +class TextualInversionTrainer(SDTrainerExtension): + uid = "textual_inversion_trainer" AI_TOOLKIT_EXTENSIONS = [ # you can put a list of extensions here - OffsetSliderTrainer + SDTrainerExtension, TextualInversionTrainer ] diff --git a/extensions_built_in/textual_inversion_trainer/config/train.example.yaml b/extensions_built_in/sd_trainer/config/train.example.yaml similarity index 100% rename from extensions_built_in/textual_inversion_trainer/config/train.example.yaml rename to extensions_built_in/sd_trainer/config/train.example.yaml diff --git a/extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py b/extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py deleted file mode 100644 index 9eb6e364..00000000 --- a/extensions_built_in/textual_inversion_trainer/TextualInversionTrainer.py +++ /dev/null @@ -1,152 +0,0 @@ -import copy -import random -from collections import OrderedDict -import os -from contextlib import nullcontext -from typing import Optional, Union, List -from torch.utils.data import ConcatDataset, DataLoader - -from toolkit.config_modules import ReferenceDatasetConfig -from toolkit.data_loader import PairedImageDataset, ImageDataset -from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds -from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds -from toolkit.train_tools import get_torch_dtype, apply_snr_weight, apply_noise_offset -import gc -from toolkit import train_tools -import torch -from jobs.process import BaseSDTrainProcess -import random -from toolkit.basic import value_map - - -def flush(): - torch.cuda.empty_cache() - gc.collect() - - -class TextualInversionTrainer(BaseSDTrainProcess): - sd: StableDiffusion - data_loader: DataLoader = None - - def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): - super().__init__(process_id, job, config, **kwargs) - pass - - def before_model_load(self): - pass - - def hook_before_train_loop(self): - self.sd.vae.eval() - self.sd.vae.to(self.device_torch) - - # keep original embeddings as reference - self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone() - # set text encoder to train. Not sure if this is necessary but diffusers example did it - self.sd.text_encoder.train() - pass - - def hook_train_loop(self, batch): - with torch.no_grad(): - imgs, prompts = batch - - # very loosely based on this. very loosely - # ref https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py - - # make sure the embedding is in the prompts - conditioned_prompts = [self.embedding.inject_embedding_to_prompt( - x, - expand_token=True, - add_if_not_present=True, - ) for x in prompts] - - batch_size = imgs.shape[0] - - dtype = get_torch_dtype(self.train_config.dtype) - imgs = imgs.to(self.device_torch, dtype=dtype) - latents = self.sd.encode_images(imgs) - - noise_scheduler = self.sd.noise_scheduler - optimizer = self.optimizer - lr_scheduler = self.lr_scheduler - - self.sd.noise_scheduler.set_timesteps( - self.train_config.max_denoising_steps, device=self.device_torch - ) - - timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch) - timesteps = timesteps.long() - - # get noise - noise = self.sd.get_latent_noise( - pixel_height=imgs.shape[2], - pixel_width=imgs.shape[3], - batch_size=batch_size, - noise_offset=self.train_config.noise_offset - ).to(self.device_torch, dtype=dtype) - - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # remove grads for these - noisy_latents.requires_grad = False - noise.requires_grad = False - - flush() - - self.optimizer.zero_grad() - noisy_latents.requires_grad = False - - # text encoding - embedding_list = [] - # embed the prompts - for prompt in conditioned_prompts: - embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype) - embedding_list.append(embedding) - conditional_embeds = concat_prompt_embeds(embedding_list) - - noise_pred = self.sd.predict_noise( - latents=noisy_latents.to(self.device_torch, dtype=dtype), - conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), - timestep=timesteps, - guidance_scale=1.0, - ) - noise = noise.to(self.device_torch, dtype=dtype) - - if self.sd.prediction_type == 'v_prediction': - # v-parameterization training - target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: - # add min_snr_gamma - loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma) - - loss = loss.mean() - - # back propagate loss to free ram - loss.backward() - flush() - - # apply gradients - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() - - # Let's make sure we don't update any embedding weights besides the newly added token - index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool) - index_no_updates[ - min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False - with torch.no_grad(): - self.sd.text_encoder.get_input_embeddings().weight[ - index_no_updates - ] = self.orig_embeds_params[index_no_updates] - - loss_dict = OrderedDict( - {'loss': loss.item()} - ) - - return loss_dict - # end hook_train_loop diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 838bd8dc..3865fdb1 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -61,11 +61,23 @@ class BaseSDTrainProcess(BaseTrainProcess): self.optimizer = None self.lr_scheduler = None self.data_loader: Union[DataLoader, None] = None + self.data_loader_reg: Union[DataLoader, None] = None + self.trigger_word = self.get_conf('trigger_word', None) raw_datasets = self.get_conf('datasets', None) self.datasets = None + self.datasets_reg = None if raw_datasets is not None and len(raw_datasets) > 0: - self.datasets = [DatasetConfig(**d) for d in raw_datasets] + for raw_dataset in raw_datasets: + dataset = DatasetConfig(**raw_dataset) + if dataset.is_reg: + if self.datasets_reg is None: + self.datasets_reg = [] + self.datasets_reg.append(dataset) + else: + if self.datasets is None: + self.datasets = [] + self.datasets.append(dataset) self.embed_config = None embedding_raw = self.get_conf('embedding', None) @@ -112,6 +124,10 @@ class BaseSDTrainProcess(BaseTrainProcess): prompt = self.embedding.inject_embedding_to_prompt( prompt, ) + if self.trigger_word is not None: + prompt = self.sd.inject_trigger_into_prompt( + prompt, self.trigger_word + ) gen_img_config_list.append(GenerateImageConfig( prompt=prompt, # it will autoparse the prompt @@ -275,6 +291,8 @@ class BaseSDTrainProcess(BaseTrainProcess): # load datasets if passed in the root process if self.datasets is not None: self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size) + if self.datasets_reg is not None: + self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size) ### HOOK ### self.hook_before_model_load() @@ -433,14 +451,29 @@ class BaseSDTrainProcess(BaseTrainProcess): dataloader = None dataloader_iterator = None + if self.data_loader_reg is not None: + dataloader_reg = self.data_loader_reg + dataloader_iterator_reg = iter(dataloader_reg) + else: + dataloader_reg = None + dataloader_iterator_reg = None + # self.step_num = 0 for step in range(self.step_num, self.train_config.steps): - if dataloader is not None: + # if is even step and we have a reg dataset, use that + # todo improve this logic to send one of each through if we can buckets and batch size might be an issue + if step % 2 == 0 and dataloader_reg is not None: + try: + batch = next(dataloader_iterator_reg) + except StopIteration: + # hit the end of an epoch, reset + dataloader_iterator_reg = iter(dataloader_reg) + batch = next(dataloader_iterator_reg) + elif dataloader is not None: try: batch = next(dataloader_iterator) except StopIteration: # hit the end of an epoch, reset - # todo, should we do something else here? like blow up balloons? dataloader_iterator = iter(dataloader) batch = next(dataloader_iterator) else: diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 46e22fda..85fa2823 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -168,6 +168,7 @@ class DatasetConfig: self.resolution: int = kwargs.get('resolution', 512) self.scale: float = kwargs.get('scale', 1.0) self.buckets: bool = kwargs.get('buckets', False) + self.is_reg: bool = kwargs.get('is_reg', False) class GenerateImageConfig: diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 2ccf6890..eabbe8d2 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -356,11 +356,16 @@ class AiToolkitDataset(Dataset, CaptionMixin): img = self.transform(img) + # todo convert it all + dataset_config_dict = { + "is_reg": 1 if self.dataset_config.is_reg else 0, + } + if self.caption_type is not None: prompt = self.get_caption_item(index) - return img, prompt + return img, prompt, dataset_config_dict else: - return img + return img, dataset_config_dict def get_dataloader_from_datasets(dataset_options, batch_size=1): diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index e4667ee6..65a3bcbd 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -515,7 +515,8 @@ class StableDiffusion: elif ts_bs * 2 == latent_model_input.shape[0]: timestep = torch.cat([timestep] * 2) else: - raise ValueError(f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}") + raise ValueError( + f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}") # predict the noise residual noise_pred = self.unet( @@ -659,6 +660,39 @@ class StableDiffusion: raise ValueError(f"Unknown weight name: {name}") + def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=True): + if trigger is None: + return prompt + output_prompt = prompt + default_replacements = ["[name]", "[trigger]"] + + replace_with = trigger + if to_replace_list is None: + to_replace_list = default_replacements + else: + to_replace_list += default_replacements + + # remove duplicates + to_replace_list = list(set(to_replace_list)) + + # replace them all + for to_replace in to_replace_list: + # replace it + output_prompt = output_prompt.replace(to_replace, replace_with) + + # see how many times replace_with is in the prompt + num_instances = prompt.count(replace_with) + + if num_instances == 0 and add_if_not_present: + # add it to the beginning of the prompt + output_prompt = replace_with + " " + output_prompt + + if num_instances > 1: + print( + f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.") + + return output_prompt + def state_dict(self, vae=True, text_encoder=True, unet=True): state_dict = OrderedDict() if vae: