Added base setup for training T2I adapters. Currently untested; I saw something else shiny that I wanted to finish first. Added content_or_style to the training config. It defaults to balanced, which is standard uniform timestep sampling. If style or content is passed, it will use cubic sampling for timesteps to favor the timesteps that are most beneficial for training that aspect. For style, favor later timesteps. For content, favor earlier timesteps.

This commit is contained in:
Jaret Burkett
2023-09-16 08:30:38 -06:00
parent 17e4fe40d7
commit 27f343fc08
8 changed files with 314 additions and 84 deletions

View File

@@ -1,17 +1,27 @@
import os.path
from collections import OrderedDict
from PIL import Image
from torch.utils.data import DataLoader
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
import gc
import torch
from jobs.process import BaseSDTrainProcess
from torchvision import transforms
def flush():
    """Release memory that is no longer referenced.

    The Python garbage collector runs first so that objects kept alive only
    by reference cycles are actually freed; only then is the CUDA caching
    allocator asked to release its unoccupied cached blocks. Doing it in the
    opposite order leaves cycle-held memory cached until the next call.
    """
    gc.collect()
    torch.cuda.empty_cache()
# Transform pipeline applied to every adapter (control) image loaded in
# SDTrainer.get_adapter_images. It only converts the PIL image to a tensor;
# NOTE(review): per torchvision docs PILToTensor does not rescale pixel
# values, so the result presumably keeps the image's integer range — confirm
# this is what the adapter expects.
adapter_transforms = transforms.Compose([
transforms.PILToTensor(),
])
class SDTrainer(BaseSDTrainProcess):
@@ -31,11 +41,47 @@ class SDTrainer(BaseSDTrainProcess):
self.sd.vae.to('cpu')
flush()
def get_adapter_images(self, batch: 'DataLoaderBatchDTO'):
    """Load the adapter (control) image matching every item of ``batch``.

    For each entry in ``batch.file_items`` the adapter image directory
    (``self.adapter_config.image_dir``) is searched for a file with the same
    name stem and one of the supported extensions (.jpg, .jpeg, .png, .webp,
    tried in that order).

    Args:
        batch: Batch whose ``file_items`` each expose a ``path`` attribute
            pointing at the training image.

    Returns:
        A tensor made by stacking one transformed adapter image per file
        item, in batch order, along a new leading dimension.

    Raises:
        FileNotFoundError: If no adapter image exists for a file item.
            Silently skipping it would leave the returned tensor misaligned
            with the rest of the batch.
    """
    img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
    adapter_folder_path = self.adapter_config.image_dir

    adapter_images = []
    for file_item in batch.file_items:
        base_name = os.path.basename(file_item.path)
        # splitext drops only the final extension, so names containing dots
        # ("img.v2.png") keep their full stem. The old first-dot stem is
        # still tried afterwards so existing datasets keep resolving.
        stems = [os.path.splitext(base_name)[0]]
        legacy_stem = base_name.split('.')[0]
        if legacy_stem not in stems:
            stems.append(legacy_stem)

        candidates = (
            os.path.join(adapter_folder_path, stem + ext)
            for stem in stems
            for ext in img_ext_list
        )
        match = next((path for path in candidates if os.path.exists(path)), None)
        if match is None:
            raise FileNotFoundError(
                f"No adapter image found for '{base_name}' in '{adapter_folder_path}' "
                f"(tried extensions: {', '.join(img_ext_list)})"
            )
        adapter_images.append(match)

    adapter_tensors = []
    for adapter_image in adapter_images:
        # Image.open is lazy and keeps the file handle open; the context
        # manager closes it once the pixels have been copied into a tensor.
        with Image.open(adapter_image) as img:
            adapter_tensors.append(adapter_transforms(img))

    # torch.stack requires every tensor to have the same shape.
    # NOTE(review): presumably all adapter images in a batch share one
    # resolution — confirm against the data loader's bucketing.
    return torch.stack(adapter_tensors)
def hook_train_loop(self, batch):
dtype = get_torch_dtype(self.train_config.dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
network_weight_list = batch.get_network_weight_list()
adapter_images = None
sigmas = None
if self.adapter:
# todo move this to data loader
adapter_images = self.get_adapter_images(batch)
# not 100% sure what this does. But they do it here
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
# flush()
self.optimizer.zero_grad()
@@ -64,30 +110,55 @@ class SDTrainer(BaseSDTrainProcess):
# detach the embeddings
conditional_embeds = conditional_embeds.detach()
# flush()
pred_kwargs = {}
if self.adapter:
down_block_additional_residuals = self.adapter(adapter_images)
down_block_additional_residuals = [
sample.to(dtype=dtype) for sample in down_block_additional_residuals
]
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs
)
# flush()
# 9.18 gb
noise = noise.to(self.device_torch, dtype=dtype).detach()
if self.adapter:
# todo, diffusers does this on t2i training, is it better approach?
# Denoise the latents
denoised_latents = noise_pred * (-sigmas) + noisy_latents
weighing = sigmas ** -2.0
if self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
# Get the target for loss depending on the prediction type
if self.sd.noise_scheduler.config.prediction_type == "epsilon":
target = batch.latents # we are computing loss against denoise latents
elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")
# MSE loss
loss = torch.mean(
(weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),
dim=1,
)
else:
target = noise
noise = noise.to(self.device_torch, dtype=dtype).detach()
if self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
# TODO: I think the sigma method does not need this. Check
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()