mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added base setup for training t2i adapters. Currently untested; saw something else shiny I wanted to finish first. Added content_or_style to the training config. It defaults to balanced, which is standard uniform timestep sampling. If style or content is passed, it will use cubic sampling for timesteps to favor the timesteps that are beneficial for training them: for style, favor later timesteps; for content, favor earlier timesteps.
This commit is contained in:
@@ -1,17 +1,27 @@
|
||||
import os.path
|
||||
from collections import OrderedDict
|
||||
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
|
||||
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
|
||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
|
||||
import gc
|
||||
import torch
|
||||
from jobs.process import BaseSDTrainProcess
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
def flush():
    """Free unreferenced Python objects, then release cached CUDA memory.

    Garbage collection runs first so that tensors kept alive only by
    reference cycles are destroyed before the CUDA caching allocator is
    asked to give its unused blocks back; in the opposite order that memory
    would stay cached until the next call.
    """
    gc.collect()
    torch.cuda.empty_cache()
|
||||
|
||||
# Transform pipeline applied to each adapter conditioning image: turns the
# PIL image into a tensor so the images of a batch can be stacked together.
adapter_transforms = transforms.Compose([transforms.PILToTensor()])
|
||||
|
||||
|
||||
class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
@@ -31,11 +41,47 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.sd.vae.to('cpu')
|
||||
flush()
|
||||
|
||||
def get_adapter_images(self, batch: 'DataLoaderBatchDTO'):
    """Load the adapter conditioning image for every item in ``batch``.

    For each entry of ``batch.file_items`` the file name (without its final
    extension) is looked up in ``self.adapter_config.image_dir`` using the
    supported image extensions, in order. The first existing file is
    opened, converted with ``adapter_transforms`` and collected.

    Args:
        batch: The training batch; only ``batch.file_items[*].path`` is read.

    Returns:
        A single tensor produced by ``torch.stack`` over the per-item image
        tensors, in the same order as ``batch.file_items``.

    Raises:
        FileNotFoundError: If no adapter image exists for a batch item.
            Skipping it silently would yield fewer adapter images than
            batch items and misalign the adapter input with the latents.
    """
    img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
    adapter_folder_path = self.adapter_config.image_dir

    adapter_tensors = []
    for file_item in batch.file_items:
        # splitext removes only the final extension, so names that contain
        # additional dots (e.g. "img.v2.png") keep their full stem.
        file_name_no_ext = os.path.splitext(os.path.basename(file_item.path))[0]

        candidate_paths = (
            os.path.join(adapter_folder_path, file_name_no_ext + ext)
            for ext in img_ext_list
        )
        adapter_image_path = next(
            (path for path in candidate_paths if os.path.exists(path)),
            None,
        )
        if adapter_image_path is None:
            raise FileNotFoundError(
                f"No adapter image found for '{file_item.path}' in "
                f"'{adapter_folder_path}' (tried extensions: {', '.join(img_ext_list)})"
            )

        # Open in a context manager so the file handle is closed once the
        # pixel data has been converted to a tensor.
        with Image.open(adapter_image_path) as img:
            adapter_tensors.append(adapter_transforms(img))

    # stack them into one batch tensor
    return torch.stack(adapter_tensors)
|
||||
|
||||
def hook_train_loop(self, batch):
|
||||
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
network_weight_list = batch.get_network_weight_list()
|
||||
|
||||
adapter_images = None
|
||||
sigmas = None
|
||||
if self.adapter:
|
||||
# todo move this to data loader
|
||||
adapter_images = self.get_adapter_images(batch)
|
||||
# not 100% sure what this does. But they do it here
|
||||
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
|
||||
sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
|
||||
noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
|
||||
|
||||
# flush()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
@@ -64,30 +110,55 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# detach the embeddings
|
||||
conditional_embeds = conditional_embeds.detach()
|
||||
# flush()
|
||||
pred_kwargs = {}
|
||||
if self.adapter:
|
||||
down_block_additional_residuals = self.adapter(adapter_images)
|
||||
down_block_additional_residuals = [
|
||||
sample.to(dtype=dtype) for sample in down_block_additional_residuals
|
||||
]
|
||||
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
|
||||
|
||||
noise_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs
|
||||
)
|
||||
# flush()
|
||||
# 9.18 gb
|
||||
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
||||
|
||||
if self.adapter:
|
||||
# todo, diffusers does this on t2i training, is it better approach?
|
||||
# Denoise the latents
|
||||
denoised_latents = noise_pred * (-sigmas) + noisy_latents
|
||||
weighing = sigmas ** -2.0
|
||||
|
||||
if self.sd.prediction_type == 'v_prediction':
|
||||
# v-parameterization training
|
||||
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
# Get the target for loss depending on the prediction type
|
||||
if self.sd.noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = batch.latents # we are computing loss against denoise latents
|
||||
elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")
|
||||
|
||||
# MSE loss
|
||||
loss = torch.mean(
|
||||
(weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),
|
||||
dim=1,
|
||||
)
|
||||
else:
|
||||
target = noise
|
||||
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
||||
if self.sd.prediction_type == 'v_prediction':
|
||||
# v-parameterization training
|
||||
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
# TODO: I think the sigma method does not need this. Check
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user