mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-24 22:33:56 +00:00
Added support for training LoRA, DreamBooth, and fine-tuning. Still needs testing and docs.
This commit is contained in:
174
extensions_built_in/sd_trainer/SDTrainer.py
Normal file
174
extensions_built_in/sd_trainer/SDTrainer.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from collections import OrderedDict
|
||||
from torch.utils.data import DataLoader
|
||||
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
|
||||
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
|
||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
|
||||
import gc
|
||||
import torch
|
||||
from jobs.process import BaseSDTrainProcess
|
||||
|
||||
|
||||
def flush():
    """Release cached CUDA memory, then run the Python garbage collector.

    Called between training phases to keep peak GPU memory down. Safe on
    CPU-only machines: `empty_cache` is a no-op when CUDA is uninitialized.
    """
    torch.cuda.empty_cache()
    gc.collect()
|
||||
|
||||
|
||||
class SDTrainer(BaseSDTrainProcess):
    """Generic diffusion training process (LoRA, DreamBooth, fine tuning).

    Runs a standard noise-prediction training loop on top of
    BaseSDTrainProcess, optionally combined with a textual-inversion
    embedding and/or a trainable network (e.g. LoRA).
    """
    sd: StableDiffusion
    data_loader: DataLoader = None

    def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
        super().__init__(process_id, job, config, **kwargs)

    def before_model_load(self):
        # nothing to prepare before the model is loaded
        pass

    def hook_before_train_loop(self):
        """Prepare the VAE (frozen, on-device) and the text encoder before training."""
        self.sd.vae.eval()
        self.sd.vae.to(self.device_torch)

        # textual inversion
        if self.embedding is not None:
            # keep original embeddings as reference so every token except the
            # newly added one(s) can be restored after each optimizer step
            self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone()
            # set text encoder to train. Not sure if this is necessary but diffusers example did it
            self.sd.text_encoder.train()

    def hook_train_loop(self, batch):
        """Run one training step on `batch`; return OrderedDict({'loss': float})."""
        with torch.no_grad():
            imgs, prompts, dataset_config = batch

            # convert the 0 or 1 for is reg to a bool list
            is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])])
            if isinstance(is_reg_list, torch.Tensor):
                # fix: Tensor.tolist() works directly; the numpy() hop was
                # unnecessary and would fail for non-CPU tensors
                is_reg_list = is_reg_list.tolist()
            is_reg_list = [bool(x) for x in is_reg_list]

            conditioned_prompts = []

            for prompt, is_reg in zip(prompts, is_reg_list):

                # make sure the embedding is in the prompts
                if self.embedding is not None:
                    prompt = self.embedding.inject_embedding_to_prompt(
                        prompt,
                        expand_token=True,
                        add_if_not_present=True,
                    )

                # make sure trigger is in the prompts if not a regularization run
                if self.trigger_word is not None and not is_reg:
                    # fix: inject_trigger_into_prompt is a no-op when trigger is
                    # None, so the trigger word must be passed explicitly
                    prompt = self.sd.inject_trigger_into_prompt(
                        prompt,
                        trigger=self.trigger_word,
                        add_if_not_present=True,
                    )
                conditioned_prompts.append(prompt)

            batch_size = imgs.shape[0]

            dtype = get_torch_dtype(self.train_config.dtype)
            imgs = imgs.to(self.device_torch, dtype=dtype)
            latents = self.sd.encode_images(imgs)

            noise_scheduler = self.sd.noise_scheduler
            optimizer = self.optimizer
            lr_scheduler = self.lr_scheduler

            self.sd.noise_scheduler.set_timesteps(
                self.train_config.max_denoising_steps, device=self.device_torch
            )

            timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
            timesteps = timesteps.long()

            # get noise
            noise = self.sd.get_latent_noise(
                pixel_height=imgs.shape[2],
                pixel_width=imgs.shape[3],
                batch_size=batch_size,
                noise_offset=self.train_config.noise_offset
            ).to(self.device_torch, dtype=dtype)

            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # remove grads for these
            noisy_latents.requires_grad = False
            noise.requires_grad = False

        flush()

        self.optimizer.zero_grad()

        # text encoding
        grad_on_text_encoder = False
        if self.train_config.train_text_encoder:
            grad_on_text_encoder = True

        if self.embedding is not None:
            grad_on_text_encoder = True

        # have a blank network so we can wrap it in a context and set multipliers without checking every time
        if self.network is not None:
            network = self.network
        else:
            network = BlankNetwork()

        # activate network if it exists
        with network:
            with torch.set_grad_enabled(grad_on_text_encoder):
                embedding_list = []
                # embed the prompts
                for prompt in conditioned_prompts:
                    embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
                    embedding_list.append(embedding)
                conditional_embeds = concat_prompt_embeds(embedding_list)

            noise_pred = self.sd.predict_noise(
                latents=noisy_latents.to(self.device_torch, dtype=dtype),
                conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
                timestep=timesteps,
                guidance_scale=1.0,
            )

            noise = noise.to(self.device_torch, dtype=dtype)

            if self.sd.prediction_type == 'v_prediction':
                # v-parameterization training
                target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
            else:
                target = noise

            loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
            loss = loss.mean([1, 2, 3])

            if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
                # add min_snr_gamma
                loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)

            loss = loss.mean()

            # back propagate loss to free ram
            loss.backward()
            flush()

        # apply gradients
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

        if self.embedding is not None:
            # Let's make sure we don't update any embedding weights besides the newly added token
            index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool)
            index_no_updates[
                min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False
            with torch.no_grad():
                self.sd.text_encoder.get_input_embeddings().weight[
                    index_no_updates
                ] = self.orig_embeds_params[index_no_updates]

        loss_dict = OrderedDict(
            {'loss': loss.item()}
        )

        return loss_dict
|
||||
@@ -2,24 +2,29 @@
|
||||
from toolkit.extension import Extension
|
||||
|
||||
|
||||
# We make a subclass of Extension
# This is for generic training (LoRA, Dreambooth, FineTuning)
class SDTrainerExtension(Extension):
    """Extension entry point for the generic SD training process."""

    # uid must be unique, it is how the extension is identified
    uid = "sd_trainer"

    # name is the name of the extension for printing
    name = "SD Trainer"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .SDTrainer import SDTrainer
        return SDTrainer
|
||||
|
||||
|
||||
# for backwards compatibility: the old textual-inversion uid now maps to
# the generic SD trainer process
class TextualInversionTrainer(SDTrainerExtension):
    uid = "textual_inversion_trainer"
|
||||
|
||||
|
||||
AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    SDTrainerExtension, TextualInversionTrainer
]
|
||||
@@ -1,152 +0,0 @@
|
||||
import copy
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional, Union, List
|
||||
from torch.utils.data import ConcatDataset, DataLoader
|
||||
|
||||
from toolkit.config_modules import ReferenceDatasetConfig
|
||||
from toolkit.data_loader import PairedImageDataset, ImageDataset
|
||||
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
|
||||
from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
|
||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, apply_noise_offset
|
||||
import gc
|
||||
from toolkit import train_tools
|
||||
import torch
|
||||
from jobs.process import BaseSDTrainProcess
|
||||
import random
|
||||
from toolkit.basic import value_map
|
||||
|
||||
|
||||
def flush():
    """Release cached CUDA memory, then run the Python garbage collector.

    Safe on CPU-only machines: `empty_cache` is a no-op when CUDA is
    uninitialized.
    """
    torch.cuda.empty_cache()
    gc.collect()
|
||||
|
||||
|
||||
class TextualInversionTrainer(BaseSDTrainProcess):
    """Textual-inversion training process.

    Trains only the newly added embedding token(s); all other text-encoder
    embedding weights are restored after every optimizer step.
    """
    sd: StableDiffusion
    data_loader: DataLoader = None

    def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
        super().__init__(process_id, job, config, **kwargs)

    def before_model_load(self):
        # nothing to prepare before the model is loaded
        pass

    def hook_before_train_loop(self):
        """Prepare the VAE and text encoder before the training loop."""
        self.sd.vae.eval()
        self.sd.vae.to(self.device_torch)

        # keep original embeddings as reference
        self.orig_embeds_params = self.sd.text_encoder.get_input_embeddings().weight.data.clone()
        # set text encoder to train. Not sure if this is necessary but diffusers example did it
        self.sd.text_encoder.train()

    def hook_train_loop(self, batch):
        """Run one textual-inversion training step; return OrderedDict({'loss': float})."""
        with torch.no_grad():
            imgs, prompts = batch

            # very loosely based on this. very loosely
            # ref https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py

            # make sure the embedding is in the prompts
            conditioned_prompts = [self.embedding.inject_embedding_to_prompt(
                x,
                expand_token=True,
                add_if_not_present=True,
            ) for x in prompts]

            batch_size = imgs.shape[0]

            dtype = get_torch_dtype(self.train_config.dtype)
            imgs = imgs.to(self.device_torch, dtype=dtype)
            latents = self.sd.encode_images(imgs)

            noise_scheduler = self.sd.noise_scheduler
            optimizer = self.optimizer
            lr_scheduler = self.lr_scheduler

            self.sd.noise_scheduler.set_timesteps(
                self.train_config.max_denoising_steps, device=self.device_torch
            )

            timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
            timesteps = timesteps.long()

            # get noise
            noise = self.sd.get_latent_noise(
                pixel_height=imgs.shape[2],
                pixel_width=imgs.shape[3],
                batch_size=batch_size,
                noise_offset=self.train_config.noise_offset
            ).to(self.device_torch, dtype=dtype)

            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # remove grads for these
            noisy_latents.requires_grad = False
            noise.requires_grad = False

        flush()

        self.optimizer.zero_grad()
        # fix: removed a duplicated `noisy_latents.requires_grad = False`
        # here; it was already set inside the no_grad block above

        # text encoding
        embedding_list = []
        # embed the prompts
        for prompt in conditioned_prompts:
            embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
            embedding_list.append(embedding)
        conditional_embeds = concat_prompt_embeds(embedding_list)

        noise_pred = self.sd.predict_noise(
            latents=noisy_latents.to(self.device_torch, dtype=dtype),
            conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
            timestep=timesteps,
            guidance_scale=1.0,
        )
        noise = noise.to(self.device_torch, dtype=dtype)

        if self.sd.prediction_type == 'v_prediction':
            # v-parameterization training
            target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
        else:
            target = noise

        loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
        loss = loss.mean([1, 2, 3])

        if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
            # add min_snr_gamma
            loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)

        loss = loss.mean()

        # back propagate loss to free ram
        loss.backward()
        flush()

        # apply gradients
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

        # Let's make sure we don't update any embedding weights besides the newly added token
        index_no_updates = torch.ones((len(self.sd.tokenizer),), dtype=torch.bool)
        index_no_updates[
            min(self.embedding.placeholder_token_ids): max(self.embedding.placeholder_token_ids) + 1] = False
        with torch.no_grad():
            self.sd.text_encoder.get_input_embeddings().weight[
                index_no_updates
            ] = self.orig_embeds_params[index_no_updates]

        loss_dict = OrderedDict(
            {'loss': loss.item()}
        )

        return loss_dict
        # end hook_train_loop
|
||||
@@ -61,11 +61,23 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.optimizer = None
|
||||
self.lr_scheduler = None
|
||||
self.data_loader: Union[DataLoader, None] = None
|
||||
self.data_loader_reg: Union[DataLoader, None] = None
|
||||
self.trigger_word = self.get_conf('trigger_word', None)
|
||||
|
||||
raw_datasets = self.get_conf('datasets', None)
|
||||
self.datasets = None
|
||||
self.datasets_reg = None
|
||||
if raw_datasets is not None and len(raw_datasets) > 0:
|
||||
self.datasets = [DatasetConfig(**d) for d in raw_datasets]
|
||||
for raw_dataset in raw_datasets:
|
||||
dataset = DatasetConfig(**raw_dataset)
|
||||
if dataset.is_reg:
|
||||
if self.datasets_reg is None:
|
||||
self.datasets_reg = []
|
||||
self.datasets_reg.append(dataset)
|
||||
else:
|
||||
if self.datasets is None:
|
||||
self.datasets = []
|
||||
self.datasets.append(dataset)
|
||||
|
||||
self.embed_config = None
|
||||
embedding_raw = self.get_conf('embedding', None)
|
||||
@@ -112,6 +124,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
prompt = self.embedding.inject_embedding_to_prompt(
|
||||
prompt,
|
||||
)
|
||||
if self.trigger_word is not None:
|
||||
prompt = self.sd.inject_trigger_into_prompt(
|
||||
prompt, self.trigger_word
|
||||
)
|
||||
|
||||
gen_img_config_list.append(GenerateImageConfig(
|
||||
prompt=prompt, # it will autoparse the prompt
|
||||
@@ -275,6 +291,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# load datasets if passed in the root process
|
||||
if self.datasets is not None:
|
||||
self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size)
|
||||
if self.datasets_reg is not None:
|
||||
self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size)
|
||||
|
||||
### HOOK ###
|
||||
self.hook_before_model_load()
|
||||
@@ -433,14 +451,29 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
dataloader = None
|
||||
dataloader_iterator = None
|
||||
|
||||
if self.data_loader_reg is not None:
|
||||
dataloader_reg = self.data_loader_reg
|
||||
dataloader_iterator_reg = iter(dataloader_reg)
|
||||
else:
|
||||
dataloader_reg = None
|
||||
dataloader_iterator_reg = None
|
||||
|
||||
# self.step_num = 0
|
||||
for step in range(self.step_num, self.train_config.steps):
|
||||
if dataloader is not None:
|
||||
# if is even step and we have a reg dataset, use that
|
||||
# todo improve this logic to send one of each through if we can buckets and batch size might be an issue
|
||||
if step % 2 == 0 and dataloader_reg is not None:
|
||||
try:
|
||||
batch = next(dataloader_iterator_reg)
|
||||
except StopIteration:
|
||||
# hit the end of an epoch, reset
|
||||
dataloader_iterator_reg = iter(dataloader_reg)
|
||||
batch = next(dataloader_iterator_reg)
|
||||
elif dataloader is not None:
|
||||
try:
|
||||
batch = next(dataloader_iterator)
|
||||
except StopIteration:
|
||||
# hit the end of an epoch, reset
|
||||
# todo, should we do something else here? like blow up balloons?
|
||||
dataloader_iterator = iter(dataloader)
|
||||
batch = next(dataloader_iterator)
|
||||
else:
|
||||
|
||||
@@ -168,6 +168,7 @@ class DatasetConfig:
|
||||
self.resolution: int = kwargs.get('resolution', 512)
|
||||
self.scale: float = kwargs.get('scale', 1.0)
|
||||
self.buckets: bool = kwargs.get('buckets', False)
|
||||
self.is_reg: bool = kwargs.get('is_reg', False)
|
||||
|
||||
|
||||
class GenerateImageConfig:
|
||||
|
||||
@@ -356,11 +356,16 @@ class AiToolkitDataset(Dataset, CaptionMixin):
|
||||
|
||||
img = self.transform(img)
|
||||
|
||||
# todo convert it all
|
||||
dataset_config_dict = {
|
||||
"is_reg": 1 if self.dataset_config.is_reg else 0,
|
||||
}
|
||||
|
||||
if self.caption_type is not None:
|
||||
prompt = self.get_caption_item(index)
|
||||
return img, prompt
|
||||
return img, prompt, dataset_config_dict
|
||||
else:
|
||||
return img
|
||||
return img, dataset_config_dict
|
||||
|
||||
|
||||
def get_dataloader_from_datasets(dataset_options, batch_size=1):
|
||||
|
||||
@@ -515,7 +515,8 @@ class StableDiffusion:
|
||||
elif ts_bs * 2 == latent_model_input.shape[0]:
|
||||
timestep = torch.cat([timestep] * 2)
|
||||
else:
|
||||
raise ValueError(f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
|
||||
raise ValueError(
|
||||
f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
@@ -659,6 +660,39 @@ class StableDiffusion:
|
||||
|
||||
raise ValueError(f"Unknown weight name: {name}")
|
||||
|
||||
def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
    """Replace placeholder tokens in a prompt with the trigger word.

    Every occurrence of "[name]" / "[trigger]" (plus any extra placeholders
    in `to_replace_list`) is replaced with `trigger`. If the trigger still
    does not appear in the result and `add_if_not_present` is True, it is
    prepended to the prompt.

    :param prompt: the prompt text to process
    :param trigger: trigger word; when None the prompt is returned unchanged
    :param to_replace_list: optional extra placeholder strings to replace
    :param add_if_not_present: prepend the trigger when it is absent
    :return: the processed prompt string
    """
    if trigger is None:
        return prompt
    output_prompt = prompt
    default_replacements = ["[name]", "[trigger]"]

    replace_with = trigger
    if to_replace_list is None:
        to_replace_list = default_replacements
    else:
        # fix: build a new list instead of += so the caller's list is not mutated
        to_replace_list = to_replace_list + default_replacements

    # remove duplicates
    to_replace_list = list(set(to_replace_list))

    # replace them all
    for to_replace in to_replace_list:
        # replace it
        output_prompt = output_prompt.replace(to_replace, replace_with)

    # see how many times replace_with is in the prompt
    # fix: count in the processed prompt, not the original — counting in the
    # original missed triggers produced by the replacements above and caused
    # the trigger to be prepended a second time
    num_instances = output_prompt.count(replace_with)

    if num_instances == 0 and add_if_not_present:
        # add it to the beginning of the prompt
        output_prompt = replace_with + " " + output_prompt

    if num_instances > 1:
        print(
            f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")

    return output_prompt
|
||||
|
||||
def state_dict(self, vae=True, text_encoder=True, unet=True):
|
||||
state_dict = OrderedDict()
|
||||
if vae:
|
||||
|
||||
Reference in New Issue
Block a user