Added some split prompting starter code, adamw8bit, caption replacement improvements, and learnable SNR GOS (gamma, offset, scale). A lot of good stuff.

Jaret Burkett
2023-11-01 06:52:21 -06:00
parent 436a09430e
commit a899ec91c8
9 changed files with 149 additions and 18 deletions

View File

@@ -12,7 +12,7 @@ from .tools.dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInf
from .tools.fuyu_utils import FuyuImageProcessor
from .tools.image_tools import load_image, ImageProcessor, resize_to_max
from .tools.llava_utils import LLaVAImageProcessor
-from .tools.caption import default_long_prompt, default_short_prompt
+from .tools.caption import default_long_prompt, default_short_prompt, default_replacements
from jobs.process import BaseExtensionProcess
from .tools.sync_tools import get_img_paths
@@ -39,6 +39,8 @@ class SuperTagger(BaseExtensionProcess):
self.caption_prompt = config.get('caption_prompt', default_long_prompt)
self.caption_short_prompt = config.get('caption_short_prompt', default_short_prompt)
self.force_reprocess_img = config.get('force_reprocess_img', False)
+self.caption_replacements = config.get('caption_replacements', default_replacements)
+self.caption_short_replacements = config.get('caption_short_replacements', default_replacements)
self.master_dataset_dict = OrderedDict()
self.dataset_master_config_file = config.get('dataset_master_config_file', None)
if parent_dir is not None and len(self.dataset_paths) == 0:
@@ -118,7 +120,8 @@ class SuperTagger(BaseExtensionProcess):
img_info.caption = self.image_processor.generate_caption(
image=caption_image,
-prompt=self.caption_prompt
+prompt=self.caption_prompt,
+replacements=self.caption_replacements
)
img_info.mark_step_complete(step)
elif step == 'caption_short':
@@ -134,7 +137,8 @@ class SuperTagger(BaseExtensionProcess):
self.image_processor.load_model()
img_info.caption_short = self.image_processor.generate_caption(
image=caption_image,
-prompt=self.caption_short_prompt
+prompt=self.caption_short_prompt,
+replacements=self.caption_short_replacements
)
img_info.mark_step_complete(step)
elif step == 'contrast_stretch':
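
Note: for reference, a hypothetical SuperTagger config snippet wiring up the new
replacement options. The pairs and the instantiation are made up; both keys fall
back to default_replacements when omitted.

    config = {
        'caption_prompt': default_long_prompt,
        'caption_short_prompt': default_short_prompt,
        # each entry is a [search, replace] pair; a leading '*' on the search text
        # blanks the whole caption when it starts with the rest of the pattern
        'caption_replacements': [['photograph of ', ''], ['*i cannot describe', '']],
        'caption_short_replacements': [['photograph of ', '']],
    }
    tagger = SuperTagger(config)  # hypothetical; the real constructor args are not shown here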

View File

@@ -33,7 +33,13 @@ def clean_caption(cap, replacements=None):
cap = " ".join(cap.split())
for replacement in replacements:
-cap = cap.replace(replacement[0], replacement[1])
+if replacement[0].startswith('*'):
+    # a leading '*' means: blank the whole caption if it starts with the rest of the pattern
+    search_text = replacement[0][1:]
+    if cap.startswith(search_text):
+        cap = ""
+else:
+    cap = cap.replace(replacement[0].lower(), replacement[1].lower())
cap_list = cap.split(",")
# trim whitespace
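
Note: a quick sketch of the new '*' convention with made-up replacement pairs; only
the loop above changed, the comma-split cleanup that follows it is untouched.

    replacements = [
        ['arafed ', ''],     # plain substring swap; the pattern and replacement are lowercased
        ['*i cannot', ''],   # leading '*': blank the entire caption if it starts with "i cannot"
    ]
    # "i cannot describe this image" -> caption is emptied before the comma-split step
    # "arafed man on a beach"        -> "man on a beach"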

View File

@@ -77,7 +77,7 @@ class LLaVAImageProcessor:
output_ids = self.model.generate(
input_ids, images=image_tensor, do_sample=True, temperature=0.1,
max_new_tokens=max_new_tokens, use_cache=True, stopping_criteria=[stopping_criteria],
-top_p=0.9
+top_p=0.8
)
outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
conv.messages[-1][-1] = outputs

View File

@@ -6,7 +6,8 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.ip_adapter import IPAdapter
from toolkit.prompt_utils import PromptEmbeds
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
-from toolkit.train_tools import get_torch_dtype, apply_snr_weight
+from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \
+    apply_learnable_snr_gos, LearnableSNRGamma
import gc
import torch
from jobs.process import BaseSDTrainProcess
@@ -59,6 +60,9 @@ class SDTrainer(BaseSDTrainProcess):
self.sd.vae.to('cpu')
flush()
self.sd.noise_scheduler.set_timesteps(1000)
+add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)
# you can expand these in a child class to make customization easier
def calculate_loss(
self,
@@ -145,7 +149,9 @@ class SDTrainer(BaseSDTrainProcess):
loss = loss.mean([1, 2, 3])
+if self.train_config.learnable_snr_gos:
+    # apply the learnable gamma/offset/scale weighting
+    loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
if self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
# add snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
@@ -315,14 +321,20 @@ class SDTrainer(BaseSDTrainProcess):
# activate network if it exists
# make the batch splits
+prompts_1 = conditioned_prompts
+prompts_2 = None
+if self.train_config.short_and_long_captions_encoder_split and self.sd.is_xl:
+    prompts_1 = batch.get_caption_short_list()
+    prompts_2 = conditioned_prompts
# make the batch splits
if self.train_config.single_item_batching:
batch_size = noisy_latents.shape[0]
# chunk/split everything
noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
noise_list = torch.chunk(noise, batch_size, dim=0)
timesteps_list = torch.chunk(timesteps, batch_size, dim=0)
-conditioned_prompts_list = [[prompt] for prompt in conditioned_prompts]
+conditioned_prompts_list = [[prompt] for prompt in prompts_1]
if imgs is not None:
imgs_list = torch.chunk(imgs, batch_size, dim=0)
else:
@@ -332,32 +344,44 @@ class SDTrainer(BaseSDTrainProcess):
else:
adapter_images_list = [None for _ in range(batch_size)]
mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
+if prompts_2 is None:
+    prompt_2_list = [None for _ in range(batch_size)]
+else:
+    prompt_2_list = [[prompt] for prompt in prompts_2]
else:
# put it all in an array
noisy_latents_list = [noisy_latents]
noise_list = [noise]
timesteps_list = [timesteps]
-conditioned_prompts_list = [conditioned_prompts]
+conditioned_prompts_list = [prompts_1]
imgs_list = [imgs]
adapter_images_list = [adapter_images]
mask_multiplier_list = [mask_multiplier]
+if prompts_2 is None:
+    prompt_2_list = [None]
+else:
+    prompt_2_list = [prompts_2]
-for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier in zip(
+for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier, prompt_2 in zip(
noisy_latents_list,
noise_list,
timesteps_list,
conditioned_prompts_list,
imgs_list,
adapter_images_list,
-mask_multiplier_list
+mask_multiplier_list,
+prompt_2_list
):
with network:
with self.timer('encode_prompt'):
if grad_on_text_encoder:
with torch.set_grad_enabled(True):
-conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
+conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to(
+# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
self.device_torch,
dtype=dtype)
else:
@@ -368,7 +392,8 @@ class SDTrainer(BaseSDTrainProcess):
te.eval()
else:
self.sd.text_encoder.eval()
-conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
+conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to(
+# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
self.device_torch,
dtype=dtype)
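
Note: taken together, the hunks above route prompts like this. A condensed sketch,
not the trainer code itself; it assumes encode_prompt sends the second prompt list
to SDXL's second text encoder and treats None as "reuse the first list".

    prompts_1 = conditioned_prompts                    # default: one list for both encoders
    prompts_2 = None
    if train_config.short_and_long_captions_encoder_split and sd.is_xl:
        prompts_1 = batch.get_caption_short_list()     # short captions -> text encoder 1
        prompts_2 = conditioned_prompts                # long captions  -> text encoder 2
    conditional_embeds = sd.encode_prompt(prompts_1, prompts_2, long_prompts=True)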

View File

@@ -1,6 +1,7 @@
import copy
import glob
import inspect
+import json
from collections import OrderedDict
import os
from typing import Union, List
@@ -36,7 +37,7 @@ from toolkit.stable_diffusion_model import StableDiffusion
from jobs.process import BaseTrainProcess
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta
-from toolkit.train_tools import get_torch_dtype
+from toolkit.train_tools import get_torch_dtype, LearnableSNRGamma
import gc
from tqdm import tqdm
@@ -158,6 +159,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.named_lora = False
if self.embed_config is not None or is_training_adapter:
self.named_lora = True
+self.snr_gos: Union[LearnableSNRGamma, None] = None
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
# override in subclass
@@ -370,6 +372,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
get_torch_dtype(self.save_config.dtype)
)
+# save learnable params as json if we have them
+if self.snr_gos:
+    json_data = {
+        'offset': self.snr_gos.offset.item(),
+        'scale': self.snr_gos.scale.item(),
+        'gamma': self.snr_gos.gamma.item(),
+    }
+    path_to_save = file_path = os.path.join(self.save_root, 'learnable_snr.json')
+    with open(path_to_save, 'w') as f:
+        json.dump(json_data, f, indent=4)
+    self.print(f"Saved to {file_path}")
self.clean_up_saves()
self.post_save_hook(file_path)
@@ -789,6 +802,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
vae = vae.to(torch.device('cpu'), dtype=dtype)
vae.requires_grad_(False)
vae.eval()
+if self.train_config.learnable_snr_gos:
+    self.snr_gos = LearnableSNRGamma(
+        self.sd.noise_scheduler, device=self.device_torch
+    )
+    # check to see if previous settings exist
+    path_to_load = os.path.join(self.save_root, 'learnable_snr.json')
+    if os.path.exists(path_to_load):
+        with open(path_to_load, 'r') as f:
+            json_data = json.load(f)
+        self.snr_gos.offset.data = torch.tensor(json_data['offset'], device=self.device_torch)
+        self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
+        self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)
flush()
### HOOK ###

View File

@@ -2,7 +2,7 @@ torch
torchvision
safetensors
diffusers==0.21.3
-git+https://github.com/huggingface/transformers.git@master
+git+https://github.com/huggingface/transformers.git
lycoris-lora==1.8.3
flatten_json
pyyaml

View File

@@ -169,6 +169,9 @@ class TrainConfig:
self.train_text_encoder = kwargs.get('train_text_encoder', True)
self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
self.snr_gamma = kwargs.get('snr_gamma', None)
+# trains a gamma, offset, and scale to adjust the loss to adapt to timestep differentials
+# this should balance the learning rate across all timesteps over time
+self.learnable_snr_gos = kwargs.get('learnable_snr_gos', False)
self.noise_offset = kwargs.get('noise_offset', 0.0)
self.skip_first_sample = kwargs.get('skip_first_sample', False)
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
@@ -190,6 +193,8 @@ class TrainConfig:
# Double up every image and run it through with both short and long captions. The idea
# is that the network will learn how to generate good images with both short and long captions
self.short_and_long_captions = kwargs.get('short_and_long_captions', False)
+# if the above is NOT true, this sends the long caption to te2 and the short caption to te1 (SDXL only)
+self.short_and_long_captions_encoder_split = kwargs.get('short_and_long_captions_encoder_split', False)
# basically gradient accumulation but we run just 1 item through the network
# and accumulate gradients. This can be used as basic gradient accumulation but is very helpful
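
Note: a hypothetical TrainConfig call exercising the new flag alongside its
neighbors; values are illustrative only.

    train_config = TrainConfig(
        learnable_snr_gos=True,                       # learn gamma/offset/scale instead of a fixed snr_gamma
        short_and_long_captions_encoder_split=True,   # SDXL only: short caption -> te1, long caption -> te2
        single_item_batching=False,                   # True runs batch items through one at a time
    )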

View File

@@ -46,6 +46,8 @@ def get_optimizer(
if lower_type == "adam8bit":
return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, **optimizer_params)
elif lower_type == "adamw8bit":
return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, **optimizer_params)
elif lower_type == "lion8bit":
return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
else:
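
Note: usage would presumably mirror the existing 8-bit branches; a sketch assuming
the call shape implied above (the full get_optimizer signature is not in this hunk,
and bitsandbytes must be installed).

    optimizer = get_optimizer(network.parameters(), 'adamw8bit', learning_rate=1e-4)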

View File

@@ -683,6 +683,68 @@ def get_all_snr(noise_scheduler, device):
all_snr.requires_grad = False
return all_snr.to(device)
+class LearnableSNRGamma:
+    """
+    Trainer for a learnable SNR gamma. It adapts to the dataset and attempts to
+    adjust the SNR multiplier so the loss stays balanced across the timesteps.
+    """
+
+    def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
+        self.device = device
+        self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
+        self.offset = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
+        self.scale = torch.nn.Parameter(torch.tensor(0.001, dtype=torch.float32, device=device))
+        self.gamma = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
+        self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.1)
+        self.buffer = []
+        self.max_buffer_size = 100
+
+    def forward(self, loss, timesteps):
+        # run our train loop for the learnable SNR here and return our values detached
+        loss = loss.detach()
+        with torch.no_grad():
+            loss_chunks = torch.chunk(loss, loss.shape[0], dim=0)
+            for loss_chunk in loss_chunks:
+                self.buffer.append(loss_chunk.mean().detach())
+                if len(self.buffer) > self.max_buffer_size:
+                    self.buffer.pop(0)
+        all_snr = get_all_snr(self.noise_scheduler, loss.device)
+        snr: torch.Tensor = torch.stack([all_snr[t] for t in timesteps]).detach().float().to(loss.device)
+        base_snrs = snr.clone().detach()
+        snr.requires_grad = True
+        snr = snr * self.scale + self.offset
+        gamma_over_snr = torch.div(torch.ones_like(snr) * self.gamma, snr)
+        snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
+        snr_adjusted_loss = loss * snr_weight
+        with torch.no_grad():
+            target = torch.mean(torch.stack(self.buffer)).detach()
+        # local_loss = torch.mean(torch.abs(snr_adjusted_loss - target))
+        squared_differences = (snr_adjusted_loss - target) ** 2
+        local_loss = torch.mean(squared_differences)
+        local_loss.backward()
+        self.optimizer.step()
+        self.optimizer.zero_grad()
+        return base_snrs, self.gamma.detach(), self.offset.detach(), self.scale.detach()
+
+
+def apply_learnable_snr_gos(
+        loss,
+        timesteps,
+        learnable_snr_trainer: LearnableSNRGamma
+):
+    snr, gamma, offset, scale = learnable_snr_trainer.forward(loss, timesteps)
+    snr = snr * scale + offset
+    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
+    snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
+    snr_adjusted_loss = loss * snr_weight
+    return snr_adjusted_loss
def apply_snr_weight(
loss,
@@ -700,5 +762,6 @@ def apply_snr_weight(
snr_weight = gamma_over_snr.float().to(loss.device) # directly using gamma over snr
else:
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)
-loss = loss * snr_weight
-return loss
+snr_adjusted_loss = loss * snr_weight
+return snr_adjusted_loss
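
Note: end to end, the intended wiring looks roughly like this. A sketch assembled
from the hunks above, not verbatim trainer code; model, loader, and target are
placeholders, and the loss passed in must already be reduced to one value per
batch item.

    snr_gos = LearnableSNRGamma(noise_scheduler, device='cuda')
    for noisy_latents, timesteps in loader:                # hypothetical training loop
        pred = model(noisy_latents, timesteps)
        loss = ((pred - target) ** 2).mean(dim=[1, 2, 3])  # per-item loss, shape [batch]
        loss = apply_learnable_snr_gos(loss, timesteps, snr_gos)
        loss.mean().backward()                             # snr_gos ran its own inner AdamW step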