mirror of https://github.com/ostris/ai-toolkit.git
Added split prompting starter code, an adamw8bit optimizer option, caption replacement improvements, and learnable SNR GOS. A lot of good stuff.
@@ -12,7 +12,7 @@ from .tools.dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInf
 from .tools.fuyu_utils import FuyuImageProcessor
 from .tools.image_tools import load_image, ImageProcessor, resize_to_max
 from .tools.llava_utils import LLaVAImageProcessor
-from .tools.caption import default_long_prompt, default_short_prompt
+from .tools.caption import default_long_prompt, default_short_prompt, default_replacements
 from jobs.process import BaseExtensionProcess
 from .tools.sync_tools import get_img_paths
 
@@ -39,6 +39,8 @@ class SuperTagger(BaseExtensionProcess):
         self.caption_prompt = config.get('caption_prompt', default_long_prompt)
         self.caption_short_prompt = config.get('caption_short_prompt', default_short_prompt)
         self.force_reprocess_img = config.get('force_reprocess_img', False)
+        self.caption_replacements = config.get('caption_replacements', default_replacements)
+        self.caption_short_replacements = config.get('caption_short_replacements', default_replacements)
         self.master_dataset_dict = OrderedDict()
         self.dataset_master_config_file = config.get('dataset_master_config_file', None)
         if parent_dir is not None and len(self.dataset_paths) == 0:
@@ -118,7 +120,8 @@ class SuperTagger(BaseExtensionProcess):
 
                     img_info.caption = self.image_processor.generate_caption(
                         image=caption_image,
-                        prompt=self.caption_prompt
+                        prompt=self.caption_prompt,
+                        replacements=self.caption_replacements
                     )
                     img_info.mark_step_complete(step)
                 elif step == 'caption_short':
@@ -134,7 +137,8 @@ class SuperTagger(BaseExtensionProcess):
                     self.image_processor.load_model()
                     img_info.caption_short = self.image_processor.generate_caption(
                         image=caption_image,
-                        prompt=self.caption_short_prompt
+                        prompt=self.caption_short_prompt,
+                        replacements=self.caption_short_replacements
                     )
                     img_info.mark_step_complete(step)
                 elif step == 'contrast_stretch':
@@ -33,7 +33,13 @@ def clean_caption(cap, replacements=None):
     cap = " ".join(cap.split())
 
     for replacement in replacements:
-        cap = cap.replace(replacement[0], replacement[1])
+        if replacement[0].startswith('*'):
+            # we are removing all text if it starts with this and the rest matches
+            search_text = replacement[0][1:]
+            if cap.startswith(search_text):
+                cap = ""
+        else:
+            cap = cap.replace(replacement[0].lower(), replacement[1].lower())
 
     cap_list = cap.split(",")
     # trim whitespace
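For reference, a minimal sketch of the replacement semantics above (these pairs are made up for illustration, not the repo's default_replacements): each entry is a (search, replace) tuple, and a leading '*' on the search text blanks the whole caption if it starts with the remainder.

    # hypothetical replacement pairs
    replacements = [
        ("arafed", ""),              # strip a common captioner artifact
        ("*there is no image", ""),  # discard captions that start with this
    ]

    cap = "there is no image to describe"
    for search, replace in replacements:
        if search.startswith('*'):
            if cap.startswith(search[1:]):
                cap = ""
        else:
            cap = cap.replace(search.lower(), replace.lower())
    print(repr(cap))  # -> ''

Note the non-'*' branch lowercases the search and replace strings but not the caption itself, so matching stays case-sensitive.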
@@ -77,7 +77,7 @@ class LLaVAImageProcessor:
         output_ids = self.model.generate(
             input_ids, images=image_tensor, do_sample=True, temperature=0.1,
             max_new_tokens=max_new_tokens, use_cache=True, stopping_criteria=[stopping_criteria],
-            top_p=0.9
+            top_p=0.8
         )
         outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
         conv.messages[-1][-1] = outputs
@@ -6,7 +6,8 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
 from toolkit.ip_adapter import IPAdapter
 from toolkit.prompt_utils import PromptEmbeds
 from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
-from toolkit.train_tools import get_torch_dtype, apply_snr_weight
+from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \
+    apply_learnable_snr_gos, LearnableSNRGamma
 import gc
 import torch
 from jobs.process import BaseSDTrainProcess
@@ -59,6 +60,9 @@ class SDTrainer(BaseSDTrainProcess):
             self.sd.vae.to('cpu')
             flush()
 
+        self.sd.noise_scheduler.set_timesteps(1000)
+        add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)
+
     # you can expand these in a child class to make customization easier
     def calculate_loss(
             self,
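add_all_snr_to_noise_scheduler itself is not shown in this diff; judging from get_all_snr further down (whose tail is visible in the train_tools hunk), it precomputes the per-timestep signal-to-noise ratio once and caches it on the scheduler. A sketch of that computation, assuming a DDPM-style scheduler where SNR(t) = alpha_bar_t / (1 - alpha_bar_t):

    import torch

    def get_all_snr_sketch(noise_scheduler, device):
        # SNR(t) = alpha_bar_t / (1 - alpha_bar_t) from the scheduler's cumulative alphas
        alphas_cumprod = noise_scheduler.alphas_cumprod
        all_snr = alphas_cumprod / (1 - alphas_cumprod)
        all_snr.requires_grad = False
        return all_snr.to(device)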
@@ -145,7 +149,9 @@ class SDTrainer(BaseSDTrainProcess):
 
         loss = loss.mean([1, 2, 3])
 
+        if self.train_config.learnable_snr_gos:
+            # add snr_gamma
+            loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
         if self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
             # add snr_gamma
             loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
@@ -315,14 +321,20 @@ class SDTrainer(BaseSDTrainProcess):
 
         # activate network if it exists
 
+        # make the batch splits
+        prompts_1 = conditioned_prompts
+        prompts_2 = None
+        if self.train_config.short_and_long_captions_encoder_split and self.sd.is_xl:
+            prompts_1 = batch.get_caption_short_list()
+            prompts_2 = conditioned_prompts
+
         # make the batch splits
         if self.train_config.single_item_batching:
             batch_size = noisy_latents.shape[0]
             # chunk/split everything
             noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
             noise_list = torch.chunk(noise, batch_size, dim=0)
             timesteps_list = torch.chunk(timesteps, batch_size, dim=0)
-            conditioned_prompts_list = [[prompt] for prompt in conditioned_prompts]
+            conditioned_prompts_list = [[prompt] for prompt in prompts_1]
             if imgs is not None:
                 imgs_list = torch.chunk(imgs, batch_size, dim=0)
             else:
@@ -332,32 +344,44 @@ class SDTrainer(BaseSDTrainProcess):
                 else:
                     adapter_images_list = [None for _ in range(batch_size)]
                 mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
+                if prompts_2 is None:
+                    prompt_2_list = [None for _ in range(batch_size)]
+                else:
+                    prompt_2_list = [[prompt] for prompt in prompts_2]
+
             else:
                 # put it all in an array
                 noisy_latents_list = [noisy_latents]
                 noise_list = [noise]
                 timesteps_list = [timesteps]
-                conditioned_prompts_list = [conditioned_prompts]
+                conditioned_prompts_list = [prompts_1]
                 imgs_list = [imgs]
                 adapter_images_list = [adapter_images]
                 mask_multiplier_list = [mask_multiplier]
+                if prompts_2 is None:
+                    prompt_2_list = [None]
+                else:
+                    prompt_2_list = [prompts_2]
 
-            for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier in zip(
+            for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier, prompt_2 in zip(
                 noisy_latents_list,
                 noise_list,
                 timesteps_list,
                 conditioned_prompts_list,
                 imgs_list,
                 adapter_images_list,
-                mask_multiplier_list
+                mask_multiplier_list,
+                prompt_2_list
             ):
 
                 with network:
                     with self.timer('encode_prompt'):
                         if grad_on_text_encoder:
                             with torch.set_grad_enabled(True):
-                                conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
+                                conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to(
+                                # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
                                     self.device_torch,
                                     dtype=dtype)
                         else:
@@ -368,7 +392,8 @@ class SDTrainer(BaseSDTrainProcess):
                                         te.eval()
                             else:
                                 self.sd.text_encoder.eval()
-                            conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to(
+                            conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to(
+                            # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
                                 self.device_torch,
                                 dtype=dtype)
 
@@ -1,6 +1,7 @@
 import copy
 import glob
 import inspect
+import json
 from collections import OrderedDict
 import os
 from typing import Union, List
@@ -36,7 +37,7 @@ from toolkit.stable_diffusion_model import StableDiffusion
 
 from jobs.process import BaseTrainProcess
 from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta
-from toolkit.train_tools import get_torch_dtype
+from toolkit.train_tools import get_torch_dtype, LearnableSNRGamma
 import gc
 
 from tqdm import tqdm
@@ -158,6 +159,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.named_lora = False
         if self.embed_config is not None or is_training_adapter:
             self.named_lora = True
+        self.snr_gos: Union[LearnableSNRGamma, None] = None
 
     def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
         # override in subclass
@@ -370,6 +372,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 get_torch_dtype(self.save_config.dtype)
             )
 
+        # save learnable params as json if we have them
+        if self.snr_gos:
+            json_data = {
+                'offset': self.snr_gos.offset.item(),
+                'scale': self.snr_gos.scale.item(),
+                'gamma': self.snr_gos.gamma.item(),
+            }
+            path_to_save = file_path = os.path.join(self.save_root, 'learnable_snr.json')
+            with open(path_to_save, 'w') as f:
+                json.dump(json_data, f, indent=4)
+
         self.print(f"Saved to {file_path}")
         self.clean_up_saves()
         self.post_save_hook(file_path)
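The resulting learnable_snr.json holds just the three scalars; illustrative contents (values made up):

    {
        "offset": 1.02,
        "scale": 0.0013,
        "gamma": 0.97
    }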
@@ -789,6 +802,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
             vae = vae.to(torch.device('cpu'), dtype=dtype)
             vae.requires_grad_(False)
             vae.eval()
+        if self.train_config.learnable_snr_gos:
+            self.snr_gos = LearnableSNRGamma(
+                self.sd.noise_scheduler, device=self.device_torch
+            )
+            # check to see if previous settings exist
+            path_to_load = os.path.join(self.save_root, 'learnable_snr.json')
+            if os.path.exists(path_to_load):
+                with open(path_to_load, 'r') as f:
+                    json_data = json.load(f)
+                self.snr_gos.offset.data = torch.tensor(json_data['offset'], device=self.device_torch)
+                self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
+                self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)
 
         flush()
 
         ### HOOK ###
@@ -2,7 +2,7 @@ torch
 torchvision
 safetensors
 diffusers==0.21.3
-git+https://github.com/huggingface/transformers.git@master
+git+https://github.com/huggingface/transformers.git
 lycoris-lora==1.8.3
 flatten_json
 pyyaml
@@ -169,6 +169,9 @@ class TrainConfig:
         self.train_text_encoder = kwargs.get('train_text_encoder', True)
         self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
         self.snr_gamma = kwargs.get('snr_gamma', None)
+        # trains a gamma, offset, and scale to adjust loss to adapt to timestep differentials
+        # this should balance the learning rate across all timesteps over time
+        self.learnable_snr_gos = kwargs.get('learnable_snr_gos', False)
         self.noise_offset = kwargs.get('noise_offset', 0.0)
         self.skip_first_sample = kwargs.get('skip_first_sample', False)
         self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
@@ -190,6 +193,8 @@ class TrainConfig:
         # Double up every image and run it through with both short and long captions. The idea
         # is that the network will learn how to generate good images with both short and long captions
         self.short_and_long_captions = kwargs.get('short_and_long_captions', False)
+        # if above is NOT true, this will make it so the long caption goes to te2 and the short caption goes to te1, for SDXL only
+        self.short_and_long_captions_encoder_split = kwargs.get('short_and_long_captions_encoder_split', False)
 
         # basically gradient accumulation but we run just 1 item through the network
         # and accumulate gradients. This can be used as basic gradient accumulation but is very helpful
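Putting the new flags together: TrainConfig reads everything through kwargs.get, so a job config only needs the matching keys. A hedged sketch (the combination shown is illustrative, not from the repo's examples):

    train_config = TrainConfig(
        snr_gamma=None,                              # leave fixed SNR weighting off
        learnable_snr_gos=True,                      # learn gamma/offset/scale from the data
        short_and_long_captions_encoder_split=True,  # SDXL only: short caption -> te1, long -> te2
    )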
@@ -46,6 +46,8 @@ def get_optimizer(
 
     if lower_type == "adam8bit":
         return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, **optimizer_params)
+    elif lower_type == "adamw8bit":
+        return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, **optimizer_params)
     elif lower_type == "lion8bit":
         return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
     else:
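The new adamw8bit branch maps directly onto bitsandbytes. A minimal standalone sketch of what it returns (the toy parameter list is illustrative):

    import bitsandbytes
    import torch

    params = [torch.nn.Parameter(torch.zeros(16))]
    # 8-bit AdamW keeps optimizer state quantized, cutting optimizer VRAM roughly 4x
    optimizer = bitsandbytes.optim.AdamW8bit(params, lr=1e-4)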
@@ -683,6 +683,68 @@ def get_all_snr(noise_scheduler, device):
     all_snr.requires_grad = False
     return all_snr.to(device)
 
+
+class LearnableSNRGamma:
+    """
+    This is a trainer for learnable snr gamma.
+    It will adapt to the dataset and attempt to adjust the snr multiplier to balance the loss over the timesteps.
+    """
+    def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
+        self.device = device
+        self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
+        self.offset = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
+        self.scale = torch.nn.Parameter(torch.tensor(0.001, dtype=torch.float32, device=device))
+        self.gamma = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
+        self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.1)
+        self.buffer = []
+        self.max_buffer_size = 100
+
+    def forward(self, loss, timesteps):
+        # do our train loop for lsnr here and return our values detached
+        loss = loss.detach()
+        with torch.no_grad():
+            loss_chunks = torch.chunk(loss, loss.shape[0], dim=0)
+            for loss_chunk in loss_chunks:
+                self.buffer.append(loss_chunk.mean().detach())
+            if len(self.buffer) > self.max_buffer_size:
+                self.buffer.pop(0)
+        all_snr = get_all_snr(self.noise_scheduler, loss.device)
+        snr: torch.Tensor = torch.stack([all_snr[t] for t in timesteps]).detach().float().to(loss.device)
+        base_snrs = snr.clone().detach()
+        snr.requires_grad = True
+        snr = snr * self.scale + self.offset
+
+        gamma_over_snr = torch.div(torch.ones_like(snr) * self.gamma, snr)
+        snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
+        snr_adjusted_loss = loss * snr_weight
+        with torch.no_grad():
+            target = torch.mean(torch.stack(self.buffer)).detach()
+
+        # local_loss = torch.mean(torch.abs(snr_adjusted_loss - target))
+        squared_differences = (snr_adjusted_loss - target) ** 2
+        local_loss = torch.mean(squared_differences)
+        local_loss.backward()
+        self.optimizer.step()
+        self.optimizer.zero_grad()
+
+        return base_snrs, self.gamma.detach(), self.offset.detach(), self.scale.detach()
+
+
+def apply_learnable_snr_gos(
+        loss,
+        timesteps,
+        learnable_snr_trainer: LearnableSNRGamma
+):
+    snr, gamma, offset, scale = learnable_snr_trainer.forward(loss, timesteps)
+
+    snr = snr * scale + offset
+
+    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
+    snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
+    snr_adjusted_loss = loss * snr_weight
+
+    return snr_adjusted_loss
+
+
 def apply_snr_weight(
         loss,
@@ -700,5 +762,6 @@ def apply_snr_weight(
         snr_weight = gamma_over_snr.float().to(loss.device)  # directly using gamma over snr
     else:
         snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)
-    loss = loss * snr_weight
-    return loss
+    snr_adjusted_loss = loss * snr_weight
+
+    return snr_adjusted_loss
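To make the two weighting schemes concrete: apply_learnable_snr_gos scales each sample's loss by |gamma / (snr * scale + offset)| with the learned parameters, while apply_snr_weight uses gamma / snr when fixed=True (as SDTrainer calls it) or the min-SNR clamp min(gamma / snr, 1) otherwise. A toy comparison with illustrative values:

    import torch

    snr = torch.tensor([0.05, 1.0, 20.0])  # low-, mid-, high-SNR timesteps (illustrative)

    # learnable GOS weight at its initial values (gamma=1.0, offset=1.0, scale=0.001)
    gos_weight = torch.abs(1.0 / (snr * 0.001 + 1.0))  # ~[1.000, 0.999, 0.980]: near-flat until trained

    # fixed min-SNR weight with gamma=5.0
    min_snr_weight = torch.minimum(5.0 / snr, torch.ones_like(snr))  # [1.00, 1.00, 0.25]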