diff --git a/extensions_built_in/dataset_tools/SuperTagger.py b/extensions_built_in/dataset_tools/SuperTagger.py
index b444216d..6eb3c70e 100644
--- a/extensions_built_in/dataset_tools/SuperTagger.py
+++ b/extensions_built_in/dataset_tools/SuperTagger.py
@@ -12,7 +12,7 @@
 from .tools.dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInfo
 from .tools.fuyu_utils import FuyuImageProcessor
 from .tools.image_tools import load_image, ImageProcessor, resize_to_max
 from .tools.llava_utils import LLaVAImageProcessor
-from .tools.caption import default_long_prompt, default_short_prompt
+from .tools.caption import default_long_prompt, default_short_prompt, default_replacements
 from jobs.process import BaseExtensionProcess
 from .tools.sync_tools import get_img_paths
@@ -39,6 +39,8 @@ class SuperTagger(BaseExtensionProcess):
         self.caption_prompt = config.get('caption_prompt', default_long_prompt)
         self.caption_short_prompt = config.get('caption_short_prompt', default_short_prompt)
         self.force_reprocess_img = config.get('force_reprocess_img', False)
+        self.caption_replacements = config.get('caption_replacements', default_replacements)
+        self.caption_short_replacements = config.get('caption_short_replacements', default_replacements)
         self.master_dataset_dict = OrderedDict()
         self.dataset_master_config_file = config.get('dataset_master_config_file', None)
         if parent_dir is not None and len(self.dataset_paths) == 0:
@@ -118,7 +120,8 @@ class SuperTagger(BaseExtensionProcess):
 
                     img_info.caption = self.image_processor.generate_caption(
                         image=caption_image,
-                        prompt=self.caption_prompt
+                        prompt=self.caption_prompt,
+                        replacements=self.caption_replacements
                     )
                     img_info.mark_step_complete(step)
                 elif step == 'caption_short':
@@ -134,7 +137,8 @@ class SuperTagger(BaseExtensionProcess):
                         self.image_processor.load_model()
                     img_info.caption_short = self.image_processor.generate_caption(
                         image=caption_image,
-                        prompt=self.caption_short_prompt
+                        prompt=self.caption_short_prompt,
+                        replacements=self.caption_short_replacements
                     )
                     img_info.mark_step_complete(step)
                 elif step == 'contrast_stretch':
diff --git a/extensions_built_in/dataset_tools/tools/caption.py b/extensions_built_in/dataset_tools/tools/caption.py
index 28daaee0..370786a8 100644
--- a/extensions_built_in/dataset_tools/tools/caption.py
+++ b/extensions_built_in/dataset_tools/tools/caption.py
@@ -33,7 +33,13 @@ def clean_caption(cap, replacements=None):
 
     cap = " ".join(cap.split())
     for replacement in replacements:
-        cap = cap.replace(replacement[0], replacement[1])
+        if replacement[0].startswith('*'):
+            # remove the entire caption if it starts with the text after the asterisk
+            search_text = replacement[0][1:]
+            if cap.startswith(search_text):
+                cap = ""
+        else:
+            cap = cap.replace(replacement[0].lower(), replacement[1].lower())
 
     cap_list = cap.split(",")
     # trim whitespace
diff --git a/extensions_built_in/dataset_tools/tools/llava_utils.py b/extensions_built_in/dataset_tools/tools/llava_utils.py
index c66e1c5e..9ba38d66 100644
--- a/extensions_built_in/dataset_tools/tools/llava_utils.py
+++ b/extensions_built_in/dataset_tools/tools/llava_utils.py
@@ -77,7 +77,7 @@ class LLaVAImageProcessor:
             output_ids = self.model.generate(
                 input_ids, images=image_tensor, do_sample=True, temperature=0.1,
                 max_new_tokens=max_new_tokens, use_cache=True, stopping_criteria=[stopping_criteria],
-                top_p=0.9
+                top_p=0.8
             )
         outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
         conv.messages[-1][-1] = outputs
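For reference, the new '*' rule in clean_caption drops a caption entirely when the text after the asterisk matches the start of the caption; plain entries remain simple substring replacements, now lowercased. The sketch below re-implements just that branch in isolation with a hypothetical replacement list; the real defaults live in default_replacements inside tools/caption.py and are not part of this diff.

```python
# Stand-alone illustration of the new replacement rules; not the toolkit's code.
# The replacement list here is made up for the example.
def apply_replacements(cap: str, replacements) -> str:
    for search, replace in replacements:
        if search.startswith('*'):
            # '*'-prefixed entries discard the whole caption when it begins with the search text
            if cap.startswith(search[1:]):
                return ""
        else:
            cap = cap.replace(search.lower(), replace.lower())
    return " ".join(cap.split())


if __name__ == "__main__":
    replacements = [
        ("*sorry, i cannot", ""),   # hypothetical: throw away refusal-style captions
        ("the image shows ", ""),   # hypothetical: strip a boilerplate prefix
    ]
    print(apply_replacements("the image shows a red car on a wet street", replacements))
    # -> "a red car on a wet street"
    print(apply_replacements("sorry, i cannot describe this image", replacements))
    # -> ""
```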
diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index f3f750bd..27346c4c 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -6,7 +6,8 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
 from toolkit.ip_adapter import IPAdapter
 from toolkit.prompt_utils import PromptEmbeds
 from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
-from toolkit.train_tools import get_torch_dtype, apply_snr_weight
+from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \
+    apply_learnable_snr_gos, LearnableSNRGamma
 import gc
 import torch
 from jobs.process import BaseSDTrainProcess
@@ -59,6 +60,9 @@ class SDTrainer(BaseSDTrainProcess):
             self.sd.vae.to('cpu')
             flush()
 
+        self.sd.noise_scheduler.set_timesteps(1000)
+        add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)
+
     # you can expand these in a child class to make customization easier
     def calculate_loss(
             self,
@@ -145,7 +149,9 @@ class SDTrainer(BaseSDTrainProcess):
 
             loss = loss.mean([1, 2, 3])
 
-
+        if self.train_config.learnable_snr_gos:
+            # add snr_gamma
+            loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
         if self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
             # add snr_gamma
             loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
@@ -315,14 +321,20 @@ class SDTrainer(BaseSDTrainProcess):
 
         # activate network if it exits
 
-        # make the batch splits
+        prompts_1 = conditioned_prompts
+        prompts_2 = None
+        if self.train_config.short_and_long_captions_encoder_split and self.sd.is_xl:
+            prompts_1 = batch.get_caption_short_list()
+            prompts_2 = conditioned_prompts
+
+        # make the batch splits
         if self.train_config.single_item_batching:
             batch_size = noisy_latents.shape[0]
             # chunk/split everything
             noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
             noise_list = torch.chunk(noise, batch_size, dim=0)
             timesteps_list = torch.chunk(timesteps, batch_size, dim=0)
-            conditioned_prompts_list = [[prompt] for prompt in conditioned_prompts]
+            conditioned_prompts_list = [[prompt] for prompt in prompts_1]
             if imgs is not None:
                 imgs_list = torch.chunk(imgs, batch_size, dim=0)
             else:
@@ -332,32 +344,44 @@
             else:
                 adapter_images_list = [None for _ in range(batch_size)]
             mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
+            if prompts_2 is None:
+                prompt_2_list = [None for _ in range(batch_size)]
+            else:
+                prompt_2_list = [[prompt] for prompt in prompts_2]
         else:
             # but it all in an array
             noisy_latents_list = [noisy_latents]
             noise_list = [noise]
             timesteps_list = [timesteps]
-            conditioned_prompts_list = [conditioned_prompts]
+            conditioned_prompts_list = [prompts_1]
             imgs_list = [imgs]
             adapter_images_list = [adapter_images]
             mask_multiplier_list = [mask_multiplier]
+            if prompts_2 is None:
+                prompt_2_list = [None]
+            else:
+                prompt_2_list = [prompts_2]
 
 
-        for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier in zip(
+
+
+        for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier, prompt_2 in zip(
                 noisy_latents_list,
                 noise_list,
                 timesteps_list,
                 conditioned_prompts_list,
                 imgs_list,
                 adapter_images_list,
-                mask_multiplier_list
+                mask_multiplier_list,
+                prompt_2_list
         ):
             with network:
                 with self.timer('encode_prompt'):
                     if grad_on_text_encoder:
                         with torch.set_grad_enabled(True):
-                            conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
+                            conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to(
+                            # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
                                 self.device_torch,
                                 dtype=dtype)
 
                     else:
@@ -368,7 +392,8 @@ class SDTrainer(BaseSDTrainProcess):
                                 te.eval()
                             else:
                                 self.sd.text_encoder.eval()
-                            conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
+                            conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to(
+                            # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
                                 self.device_torch,
                                 dtype=dtype)
 
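The short_and_long_captions_encoder_split branch above only changes which caption list feeds which SDXL text encoder: the short captions become prompts_1 (first encoder) and the full captions become prompts_2, passed as the second argument to encode_prompt. A rough sketch of just that routing decision, using a hypothetical helper and made-up captions:

```python
# Illustration only; the toolkit's actual entry point is
# StableDiffusion.encode_prompt(prompts, prompts_2, ...), which is not reproduced here.
from typing import List, Optional, Tuple


def split_captions_for_sdxl(
        long_captions: List[str],
        short_captions: List[str],
        encoder_split: bool,
        is_xl: bool,
) -> Tuple[List[str], Optional[List[str]]]:
    # mirrors the prompts_1 / prompts_2 assignment in SDTrainer above
    if encoder_split and is_xl:
        return short_captions, long_captions
    return long_captions, None


long_caps = [
    "a detailed photo of a red vintage car parked on a rainy street at night",
    "a close-up portrait of a tabby cat sitting on a windowsill",
]
short_caps = ["a red car", "a cat"]

prompts_1, prompts_2 = split_captions_for_sdxl(long_caps, short_caps, encoder_split=True, is_xl=True)
print(prompts_1)  # short captions -> text encoder 1
print(prompts_2)  # long captions  -> text encoder 2
```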
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 2f1c9176..9bf4c54a 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1,6 +1,7 @@
 import copy
 import glob
 import inspect
+import json
 from collections import OrderedDict
 import os
 from typing import Union, List
@@ -36,7 +37,7 @@
 from toolkit.stable_diffusion_model import StableDiffusion
 
 from jobs.process import BaseTrainProcess
 from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta
-from toolkit.train_tools import get_torch_dtype
+from toolkit.train_tools import get_torch_dtype, LearnableSNRGamma
 import gc
 from tqdm import tqdm
@@ -158,6 +159,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.named_lora = False
         if self.embed_config is not None or is_training_adapter:
             self.named_lora = True
+        self.snr_gos: Union[LearnableSNRGamma, None] = None
 
     def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
         # override in subclass
@@ -370,6 +372,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 get_torch_dtype(self.save_config.dtype)
             )
 
+        # save learnable params as json if we have them
+        if self.snr_gos:
+            json_data = {
+                'offset': self.snr_gos.offset.item(),
+                'scale': self.snr_gos.scale.item(),
+                'gamma': self.snr_gos.gamma.item(),
+            }
+            path_to_save = os.path.join(self.save_root, 'learnable_snr.json')
+            with open(path_to_save, 'w') as f:
+                json.dump(json_data, f, indent=4)
+            self.print(f"Saved to {path_to_save}")
 
         self.clean_up_saves()
         self.post_save_hook(file_path)
@@ -789,6 +802,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
             vae = vae.to(torch.device('cpu'), dtype=dtype)
             vae.requires_grad_(False)
             vae.eval()
+        if self.train_config.learnable_snr_gos:
+            self.snr_gos = LearnableSNRGamma(
+                self.sd.noise_scheduler, device=self.device_torch
+            )
+            # check to see if previous settings exist
+            path_to_load = os.path.join(self.save_root, 'learnable_snr.json')
+            if os.path.exists(path_to_load):
+                with open(path_to_load, 'r') as f:
+                    json_data = json.load(f)
+                self.snr_gos.offset.data = torch.tensor(json_data['offset'], device=self.device_torch)
+                self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
+                self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)
+
         flush()
 
         ### HOOk ###
diff --git a/requirements.txt b/requirements.txt
index efbccc4f..150b3dfb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@ torch
 torchvision
 safetensors
 diffusers==0.21.3
-git+https://github.com/huggingface/transformers.git@master
+git+https://github.com/huggingface/transformers.git
 lycoris-lora==1.8.3
 flatten_json
 pyyaml
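The learnable SNR values are persisted as plain JSON in the save directory so an interrupted run can restore them on the next start. A minimal, self-contained sketch of that round-trip; SimpleSNRParams and the file path are stand-ins for LearnableSNRGamma and self.save_root:

```python
# Minimal sketch of the learnable_snr.json round-trip; names and paths are placeholders.
import json
import os

import torch


class SimpleSNRParams:
    def __init__(self, device='cpu'):
        self.offset = torch.nn.Parameter(torch.tensor(1.0, device=device))
        self.scale = torch.nn.Parameter(torch.tensor(0.001, device=device))
        self.gamma = torch.nn.Parameter(torch.tensor(1.0, device=device))


def save_snr_params(params, path):
    data = {'offset': params.offset.item(), 'scale': params.scale.item(), 'gamma': params.gamma.item()}
    with open(path, 'w') as f:
        json.dump(data, f, indent=4)


def load_snr_params(params, path, device='cpu'):
    if not os.path.exists(path):
        return  # nothing saved yet, keep the defaults
    with open(path, 'r') as f:
        data = json.load(f)
    params.offset.data = torch.tensor(data['offset'], device=device)
    params.scale.data = torch.tensor(data['scale'], device=device)
    params.gamma.data = torch.tensor(data['gamma'], device=device)


params = SimpleSNRParams()
save_snr_params(params, 'learnable_snr.json')
load_snr_params(params, 'learnable_snr.json')
```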
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 267a406a..9fb51b69 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -169,6 +169,9 @@ class TrainConfig:
         self.train_text_encoder = kwargs.get('train_text_encoder', True)
         self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
         self.snr_gamma = kwargs.get('snr_gamma', None)
+        # trains a gamma, offset, and scale to adjust the loss to adapt to timestep differences
+        # this should balance the learning rate across all timesteps over time
+        self.learnable_snr_gos = kwargs.get('learnable_snr_gos', False)
         self.noise_offset = kwargs.get('noise_offset', 0.0)
         self.skip_first_sample = kwargs.get('skip_first_sample', False)
         self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
@@ -190,6 +193,8 @@ class TrainConfig:
         # Double up every image and run it through with both short and long captions. The idea
         # is that the network will learn how to generate good images with both short and long captions
         self.short_and_long_captions = kwargs.get('short_and_long_captions', False)
+        # if the above is NOT true, the long caption goes to te2 and the short caption goes to te1 (SDXL only)
+        self.short_and_long_captions_encoder_split = kwargs.get('short_and_long_captions_encoder_split', False)
 
         # basically gradient accumulation but we run just 1 item through the network
         # and accumulate gradients. This can be used as basic gradient accumulation but is very helpful
diff --git a/toolkit/optimizer.py b/toolkit/optimizer.py
index bcb15761..2332b165 100644
--- a/toolkit/optimizer.py
+++ b/toolkit/optimizer.py
@@ -46,6 +46,8 @@ def get_optimizer(
 
         if lower_type == "adam8bit":
             return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, **optimizer_params)
+        elif lower_type == "adamw8bit":
+            return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, **optimizer_params)
         elif lower_type == "lion8bit":
             return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
         else:
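Both new TrainConfig flags default to False and are read straight from kwargs, so existing configs keep working unchanged; the new "adamw8bit" optimizer string is resolved the same way the other bitsandbytes options are. A small sketch of how the options might be set together, with MiniTrainConfig standing in for the real TrainConfig:

```python
# MiniTrainConfig is a made-up stand-in for toolkit.config_modules.TrainConfig,
# showing only the kwargs handling relevant to this change.
class MiniTrainConfig:
    def __init__(self, **kwargs):
        # learned gamma/offset/scale reweighting of the loss over timesteps
        self.learnable_snr_gos = kwargs.get('learnable_snr_gos', False)
        # SDXL only: short caption -> text encoder 1, long caption -> text encoder 2
        self.short_and_long_captions_encoder_split = kwargs.get('short_and_long_captions_encoder_split', False)
        self.optimizer = kwargs.get('optimizer', 'adamw')


config = MiniTrainConfig(
    learnable_snr_gos=True,
    short_and_long_captions_encoder_split=True,
    optimizer='adamw8bit',  # now resolved to bitsandbytes.optim.AdamW8bit by get_optimizer()
)
print(config.learnable_snr_gos, config.short_and_long_captions_encoder_split, config.optimizer)
```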
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index fe7a9964..edf1ee7d 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -683,6 +683,68 @@ def get_all_snr(noise_scheduler, device):
     all_snr.requires_grad = False
     return all_snr.to(device)
 
+
+class LearnableSNRGamma:
+    """
+    This is a trainer for learnable snr gamma
+    It will adapt to the dataset and attempt to adjust the snr multiplier to balance the loss over the timesteps
+    """
+    def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
+        self.device = device
+        self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
+        self.offset = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
+        self.scale = torch.nn.Parameter(torch.tensor(0.001, dtype=torch.float32, device=device))
+        self.gamma = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
+        self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.1)
+        self.buffer = []
+        self.max_buffer_size = 100
+
+    def forward(self, loss, timesteps):
+        # do our train loop for lsnr here and return our values detached
+        loss = loss.detach()
+        with torch.no_grad():
+            loss_chunks = torch.chunk(loss, loss.shape[0], dim=0)
+            for loss_chunk in loss_chunks:
+                self.buffer.append(loss_chunk.mean().detach())
+            if len(self.buffer) > self.max_buffer_size:
+                self.buffer.pop(0)
+        all_snr = get_all_snr(self.noise_scheduler, loss.device)
+        snr: torch.Tensor = torch.stack([all_snr[t] for t in timesteps]).detach().float().to(loss.device)
+        base_snrs = snr.clone().detach()
+        snr.requires_grad = True
+        snr = snr * self.scale + self.offset
+
+        gamma_over_snr = torch.div(torch.ones_like(snr) * self.gamma, snr)
+        snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
+        snr_adjusted_loss = loss * snr_weight
+        with torch.no_grad():
+            target = torch.mean(torch.stack(self.buffer)).detach()
+
+        # local_loss = torch.mean(torch.abs(snr_adjusted_loss - target))
+        squared_differences = (snr_adjusted_loss - target) ** 2
+        local_loss = torch.mean(squared_differences)
+        local_loss.backward()
+        self.optimizer.step()
+        self.optimizer.zero_grad()
+
+        return base_snrs, self.gamma.detach(), self.offset.detach(), self.scale.detach()
+
+
+def apply_learnable_snr_gos(
+        loss,
+        timesteps,
+        learnable_snr_trainer: LearnableSNRGamma
+):
+
+    snr, gamma, offset, scale = learnable_snr_trainer.forward(loss, timesteps)
+
+    snr = snr * scale + offset
+
+    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
+    snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
+    snr_adjusted_loss = loss * snr_weight
+
+    return snr_adjusted_loss
+
 
 def apply_snr_weight(
@@ -700,5 +762,6 @@
         snr_weight = gamma_over_snr.float().to(loss.device)  # directly using gamma over snr
     else:
         snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)
-    loss = loss * snr_weight
-    return loss
+    snr_adjusted_loss = loss * snr_weight
+
+    return snr_adjusted_loss
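To make the mechanism concrete, the sketch below mimics what LearnableSNRGamma.forward does, on dummy data and with a made-up SNR table in place of a diffusers scheduler: keep a rolling buffer of recent per-sample losses, weight the current losses by |gamma / (snr * scale + offset)|, and nudge gamma, offset and scale so the weighted loss tracks the buffer mean, which flattens the effective loss across timesteps.

```python
# Dummy-data illustration of the learnable SNR weighting idea; this is not the toolkit API.
import torch

# fake per-timestep SNR table (1000 steps); the real code gets this from
# get_all_snr(noise_scheduler, device), i.e. from the scheduler's alphas_cumprod
alphas_cumprod = torch.linspace(0.9999, 0.0001, 1000)
all_snr = alphas_cumprod / (1.0 - alphas_cumprod)

offset = torch.nn.Parameter(torch.tensor(1.0))
scale = torch.nn.Parameter(torch.tensor(0.001))
gamma = torch.nn.Parameter(torch.tensor(1.0))
opt = torch.optim.AdamW([offset, scale, gamma], lr=0.1)
buffer = []

for step in range(10):
    loss = torch.rand(4)                      # pretend per-sample losses (already detached)
    timesteps = torch.randint(0, 1000, (4,))  # the timesteps they were computed at

    buffer.extend(loss.tolist())
    buffer[:] = buffer[-100:]                 # rolling target of recent losses

    snr = all_snr[timesteps] * scale + offset
    snr_weight = torch.abs(gamma / snr)
    weighted = loss * snr_weight

    target = torch.tensor(buffer).mean()
    local_loss = ((weighted - target) ** 2).mean()
    opt.zero_grad()
    local_loss.backward()
    opt.step()

print(float(gamma), float(offset), float(scale))
```

In the trainer itself, apply_learnable_snr_gos then reuses the detached gamma, offset and scale returned by forward to reweight the real training loss before backpropagation.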