Added some split prompting starter code, adamw8bit, improved caption replacements, learnable snr gos. A lot of good stuff.

This commit is contained in:
Jaret Burkett
2023-11-01 06:52:21 -06:00
parent 436a09430e
commit a899ec91c8
9 changed files with 149 additions and 18 deletions

View File

@@ -12,7 +12,7 @@ from .tools.dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInf
 from .tools.fuyu_utils import FuyuImageProcessor
 from .tools.image_tools import load_image, ImageProcessor, resize_to_max
 from .tools.llava_utils import LLaVAImageProcessor
-from .tools.caption import default_long_prompt, default_short_prompt
+from .tools.caption import default_long_prompt, default_short_prompt, default_replacements
 from jobs.process import BaseExtensionProcess
 from .tools.sync_tools import get_img_paths
@@ -39,6 +39,8 @@ class SuperTagger(BaseExtensionProcess):
         self.caption_prompt = config.get('caption_prompt', default_long_prompt)
         self.caption_short_prompt = config.get('caption_short_prompt', default_short_prompt)
         self.force_reprocess_img = config.get('force_reprocess_img', False)
+        self.caption_replacements = config.get('caption_replacements', default_replacements)
+        self.caption_short_replacements = config.get('caption_short_replacements', default_replacements)
         self.master_dataset_dict = OrderedDict()
         self.dataset_master_config_file = config.get('dataset_master_config_file', None)
         if parent_dir is not None and len(self.dataset_paths) == 0:
@@ -118,7 +120,8 @@ class SuperTagger(BaseExtensionProcess):
                 img_info.caption = self.image_processor.generate_caption(
                     image=caption_image,
-                    prompt=self.caption_prompt
+                    prompt=self.caption_prompt,
+                    replacements=self.caption_replacements
                 )
                 img_info.mark_step_complete(step)
             elif step == 'caption_short':
@@ -134,7 +137,8 @@ class SuperTagger(BaseExtensionProcess):
                 self.image_processor.load_model()
                 img_info.caption_short = self.image_processor.generate_caption(
                     image=caption_image,
-                    prompt=self.caption_short_prompt
+                    prompt=self.caption_short_prompt,
+                    replacements=self.caption_short_replacements
                 )
                 img_info.mark_step_complete(step)
             elif step == 'contrast_stretch':
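Not part of the diff: a sketch of how the two new config keys might be set. The key names come from the code above; the replacement entries themselves are hypothetical.

# hypothetical SuperTagger config snippet; only the two new keys are shown
config = {
    'caption_replacements': [
        ('arafed ', ''),             # plain find/replace pair
        ('*the image shows', ''),    # '*' prefix drops the whole caption (see clean_caption below)
    ],
    'caption_short_replacements': [
        ('a photo of ', ''),
    ],
}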

View File

@@ -33,7 +33,13 @@ def clean_caption(cap, replacements=None):
cap = " ".join(cap.split()) cap = " ".join(cap.split())
for replacement in replacements: for replacement in replacements:
cap = cap.replace(replacement[0], replacement[1]) if replacement[0].startswith('*'):
# we are removing all text if it starts with this and the rest matches
search_text = replacement[0][1:]
if cap.startswith(search_text):
cap = ""
else:
cap = cap.replace(replacement[0].lower(), replacement[1].lower())
cap_list = cap.split(",") cap_list = cap.split(",")
# trim whitespace # trim whitespace
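A small sketch (not in the commit) of the two replacement branches above; the caption strings are made up.

# made-up captions to illustrate clean_caption's two branches
replacements = [
    ('Arafed ', ''),             # non-'*' entries: both sides are lower-cased before str.replace
    ('*there is no image', ''),  # '*' entries: if the caption starts with the rest, it is blanked out
]

cleaned = clean_caption('arafed dog on a beach', replacements=replacements)        # 'arafed ' removed
cleaned = clean_caption('there is no image to describe', replacements=replacements)  # blanked to ''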

View File

@@ -77,7 +77,7 @@ class LLaVAImageProcessor:
             output_ids = self.model.generate(
                 input_ids, images=image_tensor, do_sample=True, temperature=0.1,
                 max_new_tokens=max_new_tokens, use_cache=True, stopping_criteria=[stopping_criteria],
-                top_p=0.9
+                top_p=0.8
             )
             outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
             conv.messages[-1][-1] = outputs

View File

@@ -6,7 +6,8 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
 from toolkit.ip_adapter import IPAdapter
 from toolkit.prompt_utils import PromptEmbeds
 from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
-from toolkit.train_tools import get_torch_dtype, apply_snr_weight
+from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \
+    apply_learnable_snr_gos, LearnableSNRGamma
 import gc
 import torch
 from jobs.process import BaseSDTrainProcess
@@ -59,6 +60,9 @@ class SDTrainer(BaseSDTrainProcess):
             self.sd.vae.to('cpu')
             flush()
+
+        self.sd.noise_scheduler.set_timesteps(1000)
+        add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)

     # you can expand these in a child class to make customization easier
     def calculate_loss(
             self,
@@ -145,7 +149,9 @@ class SDTrainer(BaseSDTrainProcess):
             loss = loss.mean([1, 2, 3])

+        if self.train_config.learnable_snr_gos:
+            # add snr_gamma
+            loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
         if self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
             # add snr_gamma
             loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
@@ -315,14 +321,20 @@ class SDTrainer(BaseSDTrainProcess):
         # activate network if it exits
-        # make the batch splits
+        prompts_1 = conditioned_prompts
+        prompts_2 = None
+        if self.train_config.short_and_long_captions_encoder_split and self.sd.is_xl:
+            prompts_1 = batch.get_caption_short_list()
+            prompts_2 = conditioned_prompts
+
+        # make the batch splits
         if self.train_config.single_item_batching:
             batch_size = noisy_latents.shape[0]
             # chunk/split everything
             noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
             noise_list = torch.chunk(noise, batch_size, dim=0)
             timesteps_list = torch.chunk(timesteps, batch_size, dim=0)
-            conditioned_prompts_list = [[prompt] for prompt in conditioned_prompts]
+            conditioned_prompts_list = [[prompt] for prompt in prompts_1]
             if imgs is not None:
                 imgs_list = torch.chunk(imgs, batch_size, dim=0)
             else:
@@ -332,32 +344,44 @@ class SDTrainer(BaseSDTrainProcess):
             else:
                 adapter_images_list = [None for _ in range(batch_size)]
             mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
+            if prompts_2 is None:
+                prompt_2_list = [None for _ in range(batch_size)]
+            else:
+                prompt_2_list = [[prompt] for prompt in prompts_2]
         else:
             # but it all in an array
             noisy_latents_list = [noisy_latents]
             noise_list = [noise]
             timesteps_list = [timesteps]
-            conditioned_prompts_list = [conditioned_prompts]
+            conditioned_prompts_list = [prompts_1]
             imgs_list = [imgs]
             adapter_images_list = [adapter_images]
             mask_multiplier_list = [mask_multiplier]
+            if prompts_2 is None:
+                prompt_2_list = [None]
+            else:
+                prompt_2_list = [prompts_2]

-        for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier in zip(
+        for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier, prompt_2 in zip(
                 noisy_latents_list,
                 noise_list,
                 timesteps_list,
                 conditioned_prompts_list,
                 imgs_list,
                 adapter_images_list,
-                mask_multiplier_list
+                mask_multiplier_list,
+                prompt_2_list
         ):
             with network:
                 with self.timer('encode_prompt'):
                     if grad_on_text_encoder:
                         with torch.set_grad_enabled(True):
-                            conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
+                            conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to(
+                            # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
                                 self.device_torch,
                                 dtype=dtype)
                     else:
@@ -368,7 +392,8 @@ class SDTrainer(BaseSDTrainProcess):
                                 te.eval()
                         else:
                             self.sd.text_encoder.eval()
-                        conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
+                        conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to(
+                        # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
                             self.device_torch,
                             dtype=dtype)

View File

@@ -1,6 +1,7 @@
 import copy
 import glob
 import inspect
+import json
 from collections import OrderedDict
 import os
 from typing import Union, List
@@ -36,7 +37,7 @@ from toolkit.stable_diffusion_model import StableDiffusion
 from jobs.process import BaseTrainProcess
 from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta
-from toolkit.train_tools import get_torch_dtype
+from toolkit.train_tools import get_torch_dtype, LearnableSNRGamma
 import gc
 from tqdm import tqdm
@@ -158,6 +159,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.named_lora = False
         if self.embed_config is not None or is_training_adapter:
             self.named_lora = True
+        self.snr_gos: Union[LearnableSNRGamma, None] = None

     def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
         # override in subclass
@@ -370,6 +372,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     get_torch_dtype(self.save_config.dtype)
                 )

+            # save learnable params as json if we have thim
+            if self.snr_gos:
+                json_data = {
+                    'offset': self.snr_gos.offset.item(),
+                    'scale': self.snr_gos.scale.item(),
+                    'gamma': self.snr_gos.gamma.item(),
+                }
+                path_to_save = file_path = os.path.join(self.save_root, 'learnable_snr.json')
+                with open(path_to_save, 'w') as f:
+                    json.dump(json_data, f, indent=4)
+
         self.print(f"Saved to {file_path}")
         self.clean_up_saves()
         self.post_save_hook(file_path)
@@ -789,6 +802,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
             vae = vae.to(torch.device('cpu'), dtype=dtype)
             vae.requires_grad_(False)
             vae.eval()

+        if self.train_config.learnable_snr_gos:
+            self.snr_gos = LearnableSNRGamma(
+                self.sd.noise_scheduler, device=self.device_torch
+            )
+            # check to see if previous settings exist
+            path_to_load = os.path.join(self.save_root, 'learnable_snr.json')
+            if os.path.exists(path_to_load):
+                with open(path_to_load, 'r') as f:
+                    json_data = json.load(f)
+                self.snr_gos.offset.data = torch.tensor(json_data['offset'], device=self.device_torch)
+                self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
+                self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)
+
         flush()

         ### HOOk ###
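For reference (not part of the commit), the saved learnable_snr.json holds the three scalars written above. The values below are just the initial parameter values from LearnableSNRGamma; after training they will differ. Shown as the equivalent Python dict:

# illustrative only; keys match the save code above, values are the init defaults
json_data = {
    "offset": 1.0,    # learned additive offset applied to the per-timestep SNR
    "scale": 0.001,   # learned multiplicative scale applied to the SNR
    "gamma": 1.0,     # learned gamma used in the gamma / snr weighting
}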

View File

@@ -2,7 +2,7 @@ torch
 torchvision
 safetensors
 diffusers==0.21.3
-git+https://github.com/huggingface/transformers.git@master
+git+https://github.com/huggingface/transformers.git
 lycoris-lora==1.8.3
 flatten_json
 pyyaml

View File

@@ -169,6 +169,9 @@ class TrainConfig:
         self.train_text_encoder = kwargs.get('train_text_encoder', True)
         self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
         self.snr_gamma = kwargs.get('snr_gamma', None)
+        # trains a gamma, offset, and scale to adjust loss to adapt to timestep differentials
+        # this should balance the learning rate across all timesteps over time
+        self.learnable_snr_gos = kwargs.get('learnable_snr_gos', False)
         self.noise_offset = kwargs.get('noise_offset', 0.0)
         self.skip_first_sample = kwargs.get('skip_first_sample', False)
         self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
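A minimal sketch (not in the diff) of turning the new option on, assuming TrainConfig accepts these as keyword arguments as the kwargs.get calls above suggest:

train_config = TrainConfig(
    learnable_snr_gos=True,  # learn gamma/offset/scale to reweight loss across timesteps
    snr_gamma=None,          # assumption: leave the fixed snr_gamma unset when the learnable variant is on
)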
@@ -190,6 +193,8 @@ class TrainConfig:
         # Double up every image and run it through with both short and long captions. The idea
         # is that the network will learn how to generate good images with both short and long captions
         self.short_and_long_captions = kwargs.get('short_and_long_captions', False)
+        # if above is NOT true, this will make it so the long caption foes to te2 and the short caption goes to te1 for sdxl only
+        self.short_and_long_captions_encoder_split = kwargs.get('short_and_long_captions_encoder_split', False)
         # basically gradient accumulation but we run just 1 item through the network
         # and accumulate gradients. This can be used as basic gradient accumulation but is very helpful
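Sketch of enabling the encoder split instead of the doubling behavior (same assumption about TrainConfig kwargs as above; only takes effect for SDXL, per the SDTrainer change):

train_config = TrainConfig(
    short_and_long_captions=False,               # keep caption doubling off
    short_and_long_captions_encoder_split=True,  # SDXL only: short caption -> te1, long caption -> te2
)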

View File

@@ -46,6 +46,8 @@ def get_optimizer(
         if lower_type == "adam8bit":
             return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, **optimizer_params)
+        elif lower_type == "adamw8bit":
+            return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, **optimizer_params)
         elif lower_type == "lion8bit":
             return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
         else:
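A rough usage sketch for the new branch. Only get_optimizer and the "adamw8bit" type string come from the diff; the keyword names and the params source here are assumptions for illustration.

# hypothetical call; argument names other than learning_rate are assumptions
optimizer = get_optimizer(
    params=network.parameters(),  # hypothetical trainable params
    optimizer_type='adamw8bit',   # new branch -> bitsandbytes.optim.AdamW8bit
    learning_rate=1e-4,
)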

View File

@@ -683,6 +683,68 @@ def get_all_snr(noise_scheduler, device):
     all_snr.requires_grad = False
     return all_snr.to(device)


+class LearnableSNRGamma:
+    """
+    This is a trainer for learnable snr gamma
+    It will adapt to the dataset and attempt to adjust the snr multiplier to balance the loss over the timesteps
+    """
+
+    def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
+        self.device = device
+        self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
+        self.offset = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
+        self.scale = torch.nn.Parameter(torch.tensor(0.001, dtype=torch.float32, device=device))
+        self.gamma = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
+        self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.1)
+        self.buffer = []
+        self.max_buffer_size = 100
+
+    def forward(self, loss, timesteps):
+        # do a our train loop for lsnr here and return our values detached
+        loss = loss.detach()
+        with torch.no_grad():
+            loss_chunks = torch.chunk(loss, loss.shape[0], dim=0)
+            for loss_chunk in loss_chunks:
+                self.buffer.append(loss_chunk.mean().detach())
+            if len(self.buffer) > self.max_buffer_size:
+                self.buffer.pop(0)
+        all_snr = get_all_snr(self.noise_scheduler, loss.device)
+        snr: torch.Tensor = torch.stack([all_snr[t] for t in timesteps]).detach().float().to(loss.device)
+        base_snrs = snr.clone().detach()
+        snr.requires_grad = True
+        snr = snr * self.scale + self.offset
+
+        gamma_over_snr = torch.div(torch.ones_like(snr) * self.gamma, snr)
+        snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
+        snr_adjusted_loss = loss * snr_weight
+        with torch.no_grad():
+            target = torch.mean(torch.stack(self.buffer)).detach()
+
+        # local_loss = torch.mean(torch.abs(snr_adjusted_loss - target))
+        squared_differences = (snr_adjusted_loss - target) ** 2
+        local_loss = torch.mean(squared_differences)
+        local_loss.backward()
+        self.optimizer.step()
+        self.optimizer.zero_grad()
+
+        return base_snrs, self.gamma.detach(), self.offset.detach(), self.scale.detach()
+
+
+def apply_learnable_snr_gos(
+        loss,
+        timesteps,
+        learnable_snr_trainer: LearnableSNRGamma
+):
+    snr, gamma, offset, scale = learnable_snr_trainer.forward(loss, timesteps)
+    snr = snr * scale + offset
+    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
+    snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
+    snr_adjusted_loss = loss * snr_weight
+    return snr_adjusted_loss
+
+
 def apply_snr_weight(
     loss,
@@ -700,5 +762,6 @@ def apply_snr_weight(
         snr_weight = gamma_over_snr.float().to(loss.device)  # directly using gamma over snr
     else:
         snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)
-    loss = loss * snr_weight
-    return loss
+    snr_adjusted_loss = loss * snr_weight
+
+    return snr_adjusted_loss
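A rough usage sketch of the new helpers, mirroring how SDTrainer wires them up above. The scheduler, loss values, and timesteps are placeholders; the import path follows the SDTrainer import line in this commit.

# sketch only: placeholder tensors, run on CPU for illustration
import torch
from diffusers import DDPMScheduler
from toolkit.train_tools import LearnableSNRGamma, apply_learnable_snr_gos

noise_scheduler = DDPMScheduler()            # stand-in for the training scheduler
snr_gos = LearnableSNRGamma(noise_scheduler, device='cpu')

loss = torch.rand(4)                          # per-sample loss, already reduced over [1, 2, 3]
timesteps = torch.randint(0, 1000, (4,))      # the timesteps those samples were noised at

# reweights the loss by |gamma / (snr * scale + offset)| and takes one AdamW step
# on gamma/offset/scale internally (see forward above)
weighted_loss = apply_learnable_snr_gos(loss, timesteps, snr_gos)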