ai-toolkit/toolkit/models/unified_training_model.py
import random
from typing import Union, Optional
import torch
from diffusers import T2IAdapter, ControlNetModel
from safetensors.torch import load_file
from torch import nn
from toolkit.basic import value_map
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.custom_adapter import CustomAdapter
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.embedding import Embedding
from toolkit.guidance import GuidanceType, get_guidance_loss
from toolkit.image_utils import reduce_contrast
from toolkit.ip_adapter import IPAdapter
from toolkit.network_mixins import Network
from toolkit.prompt_utils import PromptEmbeds
from toolkit.reference_adapter import ReferenceAdapter
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
from toolkit.timer import Timer, DummyTimer
from toolkit.train_tools import get_torch_dtype, apply_learnable_snr_gos, apply_snr_weight
from toolkit.config_modules import TrainConfig, AdapterConfig
AdapterType = Union[
T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, CustomAdapter, ControlNetModel
]
class UnifiedTrainingModel(nn.Module):
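"""
Wraps a StableDiffusion model together with its optional network, adapters, and
embedding so that the whole noise-prediction and loss computation happens inside
forward() and can be replicated across GPUs.

Minimal usage sketch (hypothetical wiring; the surrounding trainer is assumed to have
already built sd, network, the configs, the timer, and the batch):

    model = UnifiedTrainingModel(
        sd=sd,
        network=network,
        train_config=train_config,
        adapter_config=adapter_config,
        timer=timer,
        gpu_ids=[0, 1],
    )
    model = torch.nn.DataParallel(model, device_ids=[0, 1])
    loss = model(batch).mean()  # forward() returns the loss
    loss.backward()
"""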
def __init__(
self,
sd: StableDiffusion,
network: Optional[Network] = None,
adapter: Optional[AdapterType] = None,
assistant_adapter: Optional[AdapterType] = None,
train_config: TrainConfig = None,
adapter_config: AdapterConfig = None,
embedding: Optional[Embedding] = None,
timer: Timer = None,
trigger_word: Optional[str] = None,
gpu_ids: Optional[Union[int, list]] = None,
):
super().__init__()
self.sd: StableDiffusion = sd
self.network: Optional[Network] = network
self.adapter: Optional[AdapterType] = adapter
self.assistant_adapter: Union['T2IAdapter', 'ControlNetModel', None] = assistant_adapter
self.train_config: TrainConfig = train_config
self.adapter_config: AdapterConfig = adapter_config
self.embedding: Optional[Embedding] = embedding
self.timer: Timer = timer
self.trigger_word: Optional[str] = trigger_word
self.device_torch = torch.device("cuda")
self.gpu_ids = gpu_ids
self.primary_gpu_id = self.gpu_ids[0] if self.gpu_ids else 0  # the first in the list is primary
# misc config
self.do_long_prompts = False
self.do_prior_prediction = False
self.do_guided_loss = False
if self.train_config.do_prior_divergence:
self.do_prior_prediction = True
# register modules from sd
self.text_encoders = nn.ModuleList([self.sd.text_encoder] if not isinstance(self.sd.text_encoder, list) else self.sd.text_encoder)
self.unet = self.sd.unet
self.vae = self.sd.vae
def is_primary_gpu(self):
return torch.cuda.current_device() == self.primary_gpu_id
def before_unet_predict(self):
pass
def after_unet_predict(self):
pass
def get_prior_prediction(
self,
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
unconditional_embeds: Optional[PromptEmbeds] = None,
conditioned_prompts=None,
**kwargs
):
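"""
Run the base model with the network and (when possible) the adapter disabled on the
same noisy latents and timesteps, returning a detached "prior" prediction. It is later
used for prior preservation, adapter matching, and prediction norm correction.
"""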
# todo for embeddings, we need to run without trigger words
was_unet_training = self.sd.unet.training
was_network_active = False
if self.network is not None:
was_network_active = self.network.is_active
self.network.is_active = False
can_disable_adapter = False
was_adapter_active = False
if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or
isinstance(self.adapter, ReferenceAdapter) or
(isinstance(self.adapter, CustomAdapter))
):
can_disable_adapter = True
was_adapter_active = self.adapter.is_active
self.adapter.is_active = False
# do a prediction here so we can match its output with network multiplier set to 0.0
with torch.no_grad():
dtype = get_torch_dtype(self.train_config.dtype)
embeds_to_use = conditional_embeds.clone().detach()
# handle clip vision adapter by removing triggers from prompt and replacing with the class name
if (self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter)) or self.embedding is not None:
prompt_list = batch.get_caption_list()
class_name = ''
triggers = ['[trigger]', '[name]']
remove_tokens = []
if self.embed_config is not None:
triggers.append(self.embed_config.trigger)
for i in range(1, self.embed_config.tokens):
remove_tokens.append(f"{self.embed_config.trigger}_{i}")
if self.embed_config.trigger_class_name is not None:
class_name = self.embed_config.trigger_class_name
if self.adapter is not None:
triggers.append(self.adapter_config.trigger)
for i in range(1, self.adapter_config.num_tokens):
remove_tokens.append(f"{self.adapter_config.trigger}_{i}")
if self.adapter_config.trigger_class_name is not None:
class_name = self.adapter_config.trigger_class_name
for idx, prompt in enumerate(prompt_list):
for remove_token in remove_tokens:
prompt = prompt.replace(remove_token, '')
for trigger in triggers:
prompt = prompt.replace(trigger, class_name)
prompt_list[idx] = prompt
embeds_to_use = self.sd.encode_prompt(
prompt_list,
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype).detach()
# don't use the network on this
# self.network.multiplier = 0.0
self.sd.unet.eval()
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not self.sd.is_flux:
# we need to remove the image embeds from the prompt except for flux
embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach()
end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens
embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :]
if unconditional_embeds is not None:
unconditional_embeds = unconditional_embeds.clone().detach()
unconditional_embeds.text_embeds = unconditional_embeds.text_embeds[:, :end_pos]
if unconditional_embeds is not None:
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
prior_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(),
unconditional_embeddings=unconditional_embeds,
timestep=timesteps,
guidance_scale=self.train_config.cfg_scale,
rescale_cfg=self.train_config.cfg_rescale,
**pred_kwargs # adapter residuals in here
)
if was_unet_training:
self.sd.unet.train()
prior_pred = prior_pred.detach()
# remove the residuals since we won't use them for the prediction when matching control
if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs:
del pred_kwargs['down_intrablock_additional_residuals']
if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
del pred_kwargs['down_block_additional_residuals']
if match_adapter_assist and 'mid_block_additional_residual' in pred_kwargs:
del pred_kwargs['mid_block_additional_residual']
if can_disable_adapter:
self.adapter.is_active = was_adapter_active
# restore network
# self.network.multiplier = network_weight_list
if self.network is not None:
self.network.is_active = was_network_active
return prior_pred
def predict_noise(
self,
noisy_latents: torch.Tensor,
timesteps: Union[int, torch.Tensor] = 1,
conditional_embeds: Union[PromptEmbeds, None] = None,
unconditional_embeds: Union[PromptEmbeds, None] = None,
**kwargs,
):
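"""Thin wrapper around self.sd.predict_noise that casts inputs to the configured
training dtype and applies the configured CFG scale and rescale settings."""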
dtype = get_torch_dtype(self.train_config.dtype)
return self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeddings=unconditional_embeds,
timestep=timesteps,
guidance_scale=self.train_config.cfg_scale,
detach_unconditional=False,
rescale_cfg=self.train_config.cfg_rescale,
**kwargs
)
def get_guided_loss(
self,
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
unconditional_embeds: Optional[PromptEmbeds] = None,
**kwargs
):
loss = get_guidance_loss(
noisy_latents=noisy_latents,
conditional_embeds=conditional_embeds,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
batch=batch,
noise=noise,
sd=self.sd,
unconditional_embeds=unconditional_embeds,
scaler=self.scaler,
**kwargs
)
return loss
def get_noise(self, latents, batch_size, dtype=torch.float32):
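"""
Draw the base latent noise and, if configured, apply a random per-channel shift.

Rough sketch of the shift (values illustrative): with random_noise_shift = 0.1 each
sample/channel gets a constant offset drawn from U(-0.1, 0.1) added to every spatial
position, then the noise is rescaled so its per-channel std returns to ~1:

    shift = (torch.rand(b, c, 1, 1) * 2 - 1) * 0.1
    noise = (noise + shift) / ((noise + shift).std(dim=(2, 3), keepdim=True) + 1e-6)
"""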
# get noise
noise = self.sd.get_latent_noise(
height=latents.shape[2],
width=latents.shape[3],
batch_size=batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)
if self.train_config.random_noise_shift > 0.0:
# get random noise -1 to 1
noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
dtype=noise.dtype) * 2 - 1
# multiply by shift amount
noise_shift *= self.train_config.random_noise_shift
# add to noise
noise += noise_shift
# standardize the noise
std = noise.std(dim=(2, 3), keepdim=True)
normalizer = 1 / (std + 1e-6)
noise = noise * normalizer
return noise
def calculate_loss(
self,
noise_pred: torch.Tensor,
noise: torch.Tensor,
noisy_latents: torch.Tensor,
timesteps: torch.Tensor,
batch: 'DataLoaderBatchDTO',
mask_multiplier: Union[torch.Tensor, float] = 1.0,
prior_pred: Union[torch.Tensor, None] = None,
**kwargs
):
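"""
Compute the training loss for one batch. The regression target depends on the config:
the prior prediction when matching an adapter assistant, the scheduler velocity for
v-prediction, (noise - latents) for flow matching, and the sampled noise (epsilon)
otherwise. Mask multipliers, the inverted-mask prior term, and SNR weighting are
applied on top before the final mean reduction.
"""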
loss_target = self.train_config.loss_target
is_reg = any(batch.get_is_reg_list())
prior_mask_multiplier = None
target_mask_multiplier = None
dtype = get_torch_dtype(self.train_config.dtype)
has_mask = batch.mask_tensor is not None
with torch.no_grad():
loss_multiplier = torch.tensor(batch.loss_multiplier_list).to(self.device_torch, dtype=torch.float32)
if self.train_config.match_noise_norm:
# match the norm of the noise
noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
noise_pred = noise_pred * (noise_norm / noise_pred_norm)
if self.train_config.pred_scaler != 1.0:
noise_pred = noise_pred * self.train_config.pred_scaler
target = None
if self.train_config.target_noise_multiplier != 1.0:
noise = noise * self.train_config.target_noise_multiplier
if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask):
if self.train_config.correct_pred_norm and not is_reg:
with torch.no_grad():
# this only works if doing a prior pred
if prior_pred is not None:
prior_mean = prior_pred.mean([2,3], keepdim=True)
prior_std = prior_pred.std([2,3], keepdim=True)
noise_mean = noise_pred.mean([2,3], keepdim=True)
noise_std = noise_pred.std([2,3], keepdim=True)
mean_adjust = prior_mean - noise_mean
std_adjust = prior_std - noise_std
mean_adjust = mean_adjust * self.train_config.correct_pred_norm_multiplier
std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier
target_mean = noise_mean + mean_adjust
target_std = noise_std + std_adjust
eps = 1e-5
# match the noise to the prior
noise = (noise - noise_mean) / (noise_std + eps)
noise = noise * (target_std + eps) + target_mean
noise = noise.detach()
if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
assert not self.train_config.train_turbo
with torch.no_grad():
# build an inverted mask multiplier so the prior prediction can be preserved outside the masked region
stretched_mask_multiplier = value_map(
mask_multiplier,
batch.file_items[0].dataset_config.mask_min_value,
1.0,
0.0,
1.0
)
prior_mask_multiplier = 1.0 - stretched_mask_multiplier
# target_mask_multiplier = mask_multiplier
# mask_multiplier = 1.0
target = noise
# target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier)
# set mask multiplier to 1.0 so we don't double apply it
# mask_multiplier = 1.0
elif prior_pred is not None and not self.train_config.do_prior_divergence:
assert not self.train_config.train_turbo
# matching adapter prediction
target = prior_pred
elif self.sd.prediction_type == 'v_prediction':
# v-parameterization training
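# for reference: the velocity target is v_t = alpha_t * noise - sigma_t * x_0, which is
# what the scheduler's get_velocity computes from the clean latents and the noise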
target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps)
elif self.sd.is_flow_matching:
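# rectified-flow style target: with x_t = (1 - t) * x_0 + t * noise the model regresses
# the constant velocity (noise - x_0) along the straight path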
target = (noise - batch.latents).detach()
else:
target = noise
if target is None:
target = noise
pred = noise_pred
if self.train_config.train_turbo:
raise ValueError("Turbo training is not supported in MultiGPUSDTrainer")
ignore_snr = False
if loss_target == 'source' or loss_target == 'unaugmented':
assert not self.train_config.train_turbo
# ignore_snr = True
if batch.sigmas is None:
raise ValueError("Batch sigmas is None. This should not happen")
# src https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1190
denoised_latents = noise_pred * (-batch.sigmas) + noisy_latents
weighing = batch.sigmas ** -2.0
if loss_target == 'source':
# denoise the latent and compare to the latent in the batch
target = batch.latents
elif loss_target == 'unaugmented':
# we have to encode images into latents for now
# we also denoise since the unaugmented tensor is not a noisy differential
with torch.no_grad():
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor).to(self.device_torch, dtype=dtype)
unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier
target = unaugmented_latents.detach()
# Get the target for loss depending on the prediction type
if self.sd.noise_scheduler.config.prediction_type == "epsilon":
target = target # we are computing loss against the denoised latents
elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
target = self.sd.noise_scheduler.get_velocity(target, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")
# mse loss without reduction
loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
loss = loss_per_element
else:
if self.train_config.loss_type == "mae":
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
else:
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
# handle linear timesteps and only adjust the weight of the timesteps
if self.sd.is_flow_matching and self.train_config.linear_timesteps:
# calculate the weights for the timesteps
timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype)
timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach()
loss = loss * timestep_weight
if self.train_config.do_prior_divergence and prior_pred is not None:
loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0)
if self.train_config.train_turbo:
mask_multiplier = mask_multiplier[:, 3:, :, :]
# resize to the size of the loss
mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest')
# multiply by our mask
loss = loss * mask_multiplier
prior_loss = None
if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
assert not self.train_config.train_turbo
if self.train_config.loss_type == "mae":
prior_loss = torch.nn.functional.l1_loss(pred.float(), prior_pred.float(), reduction="none")
else:
prior_loss = torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none")
prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
if torch.isnan(prior_loss).any():
print("Prior loss is nan")
prior_loss = None
else:
prior_loss = prior_loss.mean([1, 2, 3])
# loss = loss + prior_loss
loss = loss.mean([1, 2, 3])
# apply loss multiplier before prior loss
loss = loss * loss_multiplier
if prior_loss is not None:
loss = loss + prior_loss
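# Note on SNR weighting: min-SNR weighting (https://arxiv.org/abs/2303.09556) scales
# each sample's loss by roughly min(SNR_t, gamma) / SNR_t for epsilon prediction,
# damping the low-noise (high-SNR) timesteps; the exact forms used here live in
# toolkit.train_tools.apply_snr_weight / apply_learnable_snr_gos.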
if not self.train_config.train_turbo:
if self.train_config.learnable_snr_gos:
# add snr_gamma
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
# add snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma,
fixed=True)
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
# check for additional losses
if self.adapter is not None and hasattr(self.adapter, "additional_loss") and self.adapter.additional_loss is not None:
loss = loss + self.adapter.additional_loss.mean()
self.adapter.additional_loss = None
if self.train_config.target_norm_std:
# separate out the batch and channels
pred_std = noise_pred.std([2, 3], keepdim=True)
norm_std_loss = torch.abs(self.train_config.target_norm_std_value - pred_std).mean()
loss = loss + norm_std_loss
return loss
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
return batch
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
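"""
Prepare everything the denoiser needs for this batch: trigger-injected prompts, encoded
(and optionally standardized) latents, sampled timesteps, noise, and the noised latents.

Returns a tuple of (noisy_latents, noise, timesteps, conditioned_prompts, imgs).
"""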
with torch.no_grad():
with self.timer('prepare_prompt'):
prompts = batch.get_caption_list()
is_reg_list = batch.get_is_reg_list()
is_any_reg = any(is_reg_list)
do_double = self.train_config.short_and_long_captions and not is_any_reg
if self.train_config.short_and_long_captions and do_double:
# don't do this with regs, there's no point
# double batch and add short captions to the end
prompts = prompts + batch.get_caption_short_list()
is_reg_list = is_reg_list + is_reg_list
if self.sd.model_config.refiner_name_or_path is not None and self.train_config.train_unet:
prompts = prompts + prompts
is_reg_list = is_reg_list + is_reg_list
conditioned_prompts = []
for prompt, is_reg in zip(prompts, is_reg_list):
# make sure the embedding is in the prompts
if self.embedding is not None:
prompt = self.embedding.inject_embedding_to_prompt(
prompt,
expand_token=True,
add_if_not_present=not is_reg,
)
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
prompt = self.adapter.inject_trigger_into_prompt(
prompt,
expand_token=True,
add_if_not_present=not is_reg,
)
# make sure trigger is in the prompts if not a regularization run
if self.trigger_word is not None:
prompt = self.sd.inject_trigger_into_prompt(
prompt,
trigger=self.trigger_word,
add_if_not_present=not is_reg,
)
if not is_reg and self.train_config.prompt_saturation_chance > 0.0:
# do random prompt saturation by expanding the prompt to hit at least 77 tokens
if random.random() < self.train_config.prompt_saturation_chance:
est_num_tokens = len(prompt.split(' '))
if est_num_tokens < 77:
num_repeats = int(77 / est_num_tokens) + 1
prompt = ', '.join([prompt] * num_repeats)
conditioned_prompts.append(prompt)
with self.timer('prepare_latents'):
dtype = get_torch_dtype(self.train_config.dtype)
imgs = None
is_reg = any(batch.get_is_reg_list())
if batch.tensor is not None:
imgs = batch.tensor
imgs = imgs.to(self.device_torch, dtype=dtype)
# don't adjust for regs
if self.train_config.img_multiplier is not None and not is_reg:
# apply it as a contrast reduction
imgs = reduce_contrast(imgs, self.train_config.img_multiplier)
if batch.latents is not None:
latents = batch.latents.to(self.device_torch, dtype=dtype)
batch.latents = latents
else:
# normalize the images to the target channel statistics before encoding
if self.train_config.standardize_images:
if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd:
target_mean_list = [0.0002, -0.1034, -0.1879]
target_std_list = [0.5436, 0.5116, 0.5033]
else:
target_mean_list = [-0.0739, -0.1597, -0.2380]
target_std_list = [0.5623, 0.5295, 0.5347]
imgs_channel_mean = imgs.mean(dim=(2, 3), keepdim=True)
imgs_channel_std = imgs.std(dim=(2, 3), keepdim=True)
imgs = (imgs - imgs_channel_mean) / imgs_channel_std
target_mean = torch.tensor(target_mean_list, device=self.device_torch, dtype=dtype)
target_std = torch.tensor(target_std_list, device=self.device_torch, dtype=dtype)
# expand them to match dim
target_mean = target_mean.unsqueeze(0).unsqueeze(2).unsqueeze(3)
target_std = target_std.unsqueeze(0).unsqueeze(2).unsqueeze(3)
imgs = imgs * target_std + target_mean
batch.tensor = imgs
# show_tensors(imgs, 'imgs')
latents = self.sd.encode_images(imgs)
batch.latents = latents
if self.train_config.standardize_latents:
if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd:
target_mean_list = [-0.1075, 0.0231, -0.0135, 0.2164]
target_std_list = [0.8979, 0.7505, 0.9150, 0.7451]
else:
target_mean_list = [0.2949, -0.3188, 0.0807, 0.1929]
target_std_list = [0.8560, 0.9629, 0.7778, 0.6719]
latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True)
latents_channel_std = latents.std(dim=(2, 3), keepdim=True)
latents = (latents - latents_channel_mean) / latents_channel_std
target_mean = torch.tensor(target_mean_list, device=self.device_torch, dtype=dtype)
target_std = torch.tensor(target_std_list, device=self.device_torch, dtype=dtype)
# expand them to match dim
target_mean = target_mean.unsqueeze(0).unsqueeze(2).unsqueeze(3)
target_std = target_std.unsqueeze(0).unsqueeze(2).unsqueeze(3)
latents = latents * target_std + target_mean
batch.latents = latents
# show_latents(latents, self.sd.vae, 'latents')
if batch.unconditional_tensor is not None and batch.unconditional_latents is None:
unconditional_imgs = batch.unconditional_tensor
unconditional_imgs = unconditional_imgs.to(self.device_torch, dtype=dtype)
unconditional_latents = self.sd.encode_images(unconditional_imgs)
batch.unconditional_latents = unconditional_latents * self.train_config.latent_multiplier
unaugmented_latents = None
if self.train_config.loss_target == 'differential_noise':
# we determine noise from the differential of the latents
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
batch_size = len(batch.file_items)
min_noise_steps = self.train_config.min_denoising_steps
max_noise_steps = self.train_config.max_denoising_steps
if self.sd.model_config.refiner_name_or_path is not None:
# if we are not training the unet, then we are only doing refiner and do not need to double up
if self.train_config.train_unet:
max_noise_steps = round(self.train_config.max_denoising_steps * self.sd.model_config.refiner_start_at)
do_double = True
else:
min_noise_steps = round(self.train_config.max_denoising_steps * self.sd.model_config.refiner_start_at)
do_double = False
with self.timer('prepare_noise'):
num_train_timesteps = self.train_config.num_train_timesteps
if self.train_config.noise_scheduler in ['custom_lcm']:
# we store this value on our custom one
self.sd.noise_scheduler.set_timesteps(
self.sd.noise_scheduler.train_timesteps, device=self.device_torch
)
elif self.train_config.noise_scheduler in ['lcm']:
self.sd.noise_scheduler.set_timesteps(
num_train_timesteps, device=self.device_torch, original_inference_steps=num_train_timesteps
)
elif self.train_config.noise_scheduler == 'flowmatch':
self.sd.noise_scheduler.set_train_timesteps(
num_train_timesteps,
device=self.device_torch,
linear=self.train_config.linear_timesteps
)
else:
self.sd.noise_scheduler.set_timesteps(
num_train_timesteps, device=self.device_torch
)
content_or_style = self.train_config.content_or_style
if is_reg:
content_or_style = self.train_config.content_or_style_reg
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
if content_or_style in ['style', 'content']:
# this is from diffusers training code
# Cubic sampling for favoring later or earlier timesteps
# For more details about why cubic sampling is used for content / structure,
# refer to section 3.4 of https://arxiv.org/abs/2302.08453
# for content / structure, it is best to favor earlier timesteps
# for style, it is best to favor later timesteps
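# e.g. with u ~ U(0,1): u**3 concentrates indices near 0 (the start of the scheduler's
# timestep list) for 'content', while 1 - u**3 concentrates them near the end for
# 'style'; the indices are mapped into scheduler.timesteps below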
orig_timesteps = torch.rand((batch_size,), device=latents.device)
if content_or_style == 'content':
timestep_indices = orig_timesteps ** 3 * self.train_config.num_train_timesteps
elif content_or_style == 'style':
timestep_indices = (1 - orig_timesteps ** 3) * self.train_config.num_train_timesteps
timestep_indices = value_map(
timestep_indices,
0,
self.train_config.num_train_timesteps - 1,
min_noise_steps,
max_noise_steps - 1
)
timestep_indices = timestep_indices.long().clamp(
min_noise_steps + 1,
max_noise_steps - 1
)
elif content_or_style == 'balanced':
if min_noise_steps == max_noise_steps:
timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
else:
# todo, some schedulers use indices, others use timesteps. Not sure what to do here
timestep_indices = torch.randint(
min_noise_steps + 1,
max_noise_steps - 1,
(batch_size,),
device=self.device_torch
)
timestep_indices = timestep_indices.long()
else:
raise ValueError(f"Unknown content_or_style {content_or_style}")
# do flow matching
# if self.sd.is_flow_matching:
# u = compute_density_for_timestep_sampling(
# weighting_scheme="logit_normal", # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
# batch_size=batch_size,
# logit_mean=0.0,
# logit_std=1.0,
# mode_scale=1.29,
# )
# timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
# convert the timestep_indices to a timestep
timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
timesteps = torch.stack(timesteps, dim=0)
# get noise
noise = self.get_noise(latents, batch_size, dtype=dtype)
# add dynamic noise offset: shift the noise toward the channelwise mean of the latents
# (this will negate any fixed noise offsets)
if self.train_config.dynamic_noise_offset and not is_reg:
latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True) / 2
# add half the per-channel latent mean so the noise offset compensates for the mean of the latents
noise = noise + latents_channel_mean
if self.train_config.loss_target == 'differential_noise':
differential = latents - unaugmented_latents
# add noise to differential
# noise = noise + differential
noise = noise + (differential * 0.5)
# noise = value_map(differential, 0, torch.abs(differential).max(), 0, torch.abs(noise).max())
latents = unaugmented_latents
noise_multiplier = self.train_config.noise_multiplier
noise = noise * noise_multiplier
latent_multiplier = self.train_config.latent_multiplier
# handle adaptive scaling based on std
if self.train_config.adaptive_scaling_factor:
std = latents.std(dim=(2, 3), keepdim=True)
normalizer = 1 / (std + 1e-6)
latent_multiplier = normalizer
latents = latents * latent_multiplier
batch.latents = latents
# normalize latents to a mean of 0 and an std of 1
# mean_zero_latents = latents - latents.mean()
# latents = mean_zero_latents / mean_zero_latents.std()
if batch.unconditional_latents is not None:
batch.unconditional_latents = batch.unconditional_latents * self.train_config.latent_multiplier
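# forward-diffuse the latents: for DDPM-style schedulers add_noise computes
# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, while flow-matching
# schedulers interpolate linearly between the latents and the noise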
noisy_latents = self.sd.add_noise(latents, noise, timesteps)
# determine scaled noise
# todo do we need to scale this or does it always predict full intensity
# noise = noisy_latents - latents
# https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170C17-L1171C77
if self.train_config.loss_target == 'source' or self.train_config.loss_target == 'unaugmented':
sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
# add it to the batch
batch.sigmas = sigmas
# todo is this for sdxl? find out where this came from originally
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
def double_up_tensor(tensor: torch.Tensor):
if tensor is None:
return None
return torch.cat([tensor, tensor], dim=0)
if do_double:
if self.sd.model_config.refiner_name_or_path:
# apply refiner double up
refiner_timesteps = torch.randint(
max_noise_steps,
self.train_config.max_denoising_steps,
(batch_size,),
device=self.device_torch
)
refiner_timesteps = refiner_timesteps.long()
# add our new timesteps on to end
timesteps = torch.cat([timesteps, refiner_timesteps], dim=0)
refiner_noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, refiner_timesteps)
noisy_latents = torch.cat([noisy_latents, refiner_noisy_latents], dim=0)
else:
# just double it
noisy_latents = double_up_tensor(noisy_latents)
timesteps = double_up_tensor(timesteps)
noise = double_up_tensor(noise)
# prompts are already updated above
imgs = double_up_tensor(imgs)
batch.mask_tensor = double_up_tensor(batch.mask_tensor)
batch.control_tensor = double_up_tensor(batch.control_tensor)
noisy_latent_multiplier = self.train_config.noisy_latent_multiplier
if noisy_latent_multiplier != 1.0:
noisy_latents = noisy_latents * noisy_latent_multiplier
# remove grads for these
noisy_latents.requires_grad = False
noisy_latents = noisy_latents.detach()
noise.requires_grad = False
noise = noise.detach()
return noisy_latents, noise, timesteps, conditioned_prompts, imgs
def forward(self, batch: DataLoaderBatchDTO):
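"""
Run one full training step for the given batch (prompt conditioning, adapter handling,
noise prediction, loss) and return the loss tensor.

When wrapped by a multi-GPU replication wrapper each replica runs this on its own
device; only the primary GPU keeps the real timer.
"""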
if not self.is_primary_gpu():
# replace timer with dummy one
self.timer = DummyTimer()
self.device_torch = torch.device(f"cuda:{torch.cuda.current_device()}")
self.timer.start('preprocess_batch')
batch = self.preprocess_batch(batch)
dtype = get_torch_dtype(self.train_config.dtype)
# sanity check
if self.sd.vae.dtype != self.sd.vae_torch_dtype:
self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
if isinstance(self.sd.text_encoder, list):
for encoder in self.sd.text_encoder:
if encoder.dtype != self.sd.te_torch_dtype:
encoder.to(self.sd.te_torch_dtype)
else:
if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
self.sd.text_encoder.to(self.sd.te_torch_dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
if self.train_config.do_cfg or self.train_config.do_random_cfg:
# pick random negative prompts
if self.negative_prompt_pool is not None:
negative_prompts = []
for i in range(noisy_latents.shape[0]):
num_neg = random.randint(1, self.train_config.max_negative_prompts)
this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)]
this_neg_prompt = ', '.join(this_neg_prompts)
negative_prompts.append(this_neg_prompt)
self.batch_negative_prompt = negative_prompts
else:
self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])]
if self.adapter and isinstance(self.adapter, CustomAdapter):
# condition the prompt
# todo handle more than one adapter image
self.adapter.num_control_images = 1
conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)
network_weight_list = batch.get_network_weight_list()
if self.train_config.single_item_batching:
network_weight_list = network_weight_list + network_weight_list
has_adapter_img = batch.control_tensor is not None
has_clip_image = batch.clip_image_tensor is not None
has_clip_image_embeds = batch.clip_image_embeds is not None
# force it to be true if doing regs as we handle those differently
if any(file_item.is_reg for file_item in batch.file_items):
has_clip_image = True
if self._clip_image_embeds_unconditional is not None:
has_clip_image_embeds = True # we are caching embeds, handle that differently
has_clip_image = False
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
raise ValueError(
"IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
match_adapter_assist = False
# check if we are matching the adapter assistant
if self.assistant_adapter:
if self.train_config.match_adapter_chance == 1.0:
match_adapter_assist = True
elif self.train_config.match_adapter_chance > 0.0:
match_adapter_assist = torch.rand(
(1,), device=self.device_torch, dtype=dtype
) < self.train_config.match_adapter_chance
self.timer.stop('preprocess_batch')
is_reg = False
with torch.no_grad():
loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
for idx, file_item in enumerate(batch.file_items):
if file_item.is_reg:
loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight
is_reg = True
adapter_images = None
sigmas = None
if has_adapter_img and (self.adapter or self.assistant_adapter):
with self.timer('get_adapter_images'):
# todo move this to data loader
if batch.control_tensor is not None:
adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
# match in channels
if self.assistant_adapter is not None:
in_channels = self.assistant_adapter.config.in_channels
if adapter_images.shape[1] != in_channels:
# we need to match the channels
adapter_images = adapter_images[:, :in_channels, :, :]
else:
raise NotImplementedError("Adapter images now must be loaded with dataloader")
clip_images = None
if has_clip_image:
with self.timer('get_clip_images'):
# todo move this to data loader
if batch.clip_image_tensor is not None:
clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach()
mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
if batch.mask_tensor is not None:
with self.timer('get_mask_multiplier'):
# upsampling is not supported for bfloat16, so interpolate in float16
mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
# scale down to the size of the latents; mask_multiplier shape (bs, 1, h, w), noisy_latents shape (bs, channels, h, w)
mask_multiplier = torch.nn.functional.interpolate(
mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
)
# expand to match latents
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
def get_adapter_multiplier():
if self.adapter and isinstance(self.adapter, T2IAdapter):
# training a t2i adapter, not using as assistant.
return 1.0
elif match_adapter_assist:
# matching the adapter assistant, so keep the conditioning strength high
adapter_strength_min = 0.9
adapter_strength_max = 1.0
else:
# training with assistance, we want it low
# adapter_strength_min = 0.4
# adapter_strength_max = 0.7
adapter_strength_min = 0.5
adapter_strength_max = 1.1
adapter_conditioning_scale = torch.rand(
(1,), device=self.device_torch, dtype=dtype
)
adapter_conditioning_scale = value_map(
adapter_conditioning_scale,
0.0,
1.0,
adapter_strength_min,
adapter_strength_max
)
return adapter_conditioning_scale
# flush()
with self.timer('grad_setup'):
# text encoding
grad_on_text_encoder = False
if self.train_config.train_text_encoder:
grad_on_text_encoder = True
if self.embedding is not None:
grad_on_text_encoder = True
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
grad_on_text_encoder = True
if self.adapter_config and self.adapter_config.type == 'te_augmenter':
grad_on_text_encoder = True
# have a blank network so we can wrap it in a context and set multipliers without checking every time
if self.network is not None:
network = self.network
else:
network = BlankNetwork()
# set the weights
network.multiplier = network_weight_list
# activate the network if it exists
prompts_1 = conditioned_prompts
prompts_2 = None
if self.train_config.short_and_long_captions_encoder_split and self.sd.is_xl:
prompts_1 = batch.get_caption_short_list()
prompts_2 = conditioned_prompts
# make the batch splits
if self.train_config.single_item_batching:
if self.sd.model_config.refiner_name_or_path is not None:
raise ValueError("Single item batching is not supported when training the refiner")
batch_size = noisy_latents.shape[0]
# chunk/split everything
noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
noise_list = torch.chunk(noise, batch_size, dim=0)
timesteps_list = torch.chunk(timesteps, batch_size, dim=0)
conditioned_prompts_list = [[prompt] for prompt in prompts_1]
if imgs is not None:
imgs_list = torch.chunk(imgs, batch_size, dim=0)
else:
imgs_list = [None for _ in range(batch_size)]
if adapter_images is not None:
adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0)
else:
adapter_images_list = [None for _ in range(batch_size)]
if clip_images is not None:
clip_images_list = torch.chunk(clip_images, batch_size, dim=0)
else:
clip_images_list = [None for _ in range(batch_size)]
mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
if prompts_2 is None:
prompt_2_list = [None for _ in range(batch_size)]
else:
prompt_2_list = [[prompt] for prompt in prompts_2]
else:
noisy_latents_list = [noisy_latents]
noise_list = [noise]
timesteps_list = [timesteps]
conditioned_prompts_list = [prompts_1]
imgs_list = [imgs]
adapter_images_list = [adapter_images]
clip_images_list = [clip_images]
mask_multiplier_list = [mask_multiplier]
if prompts_2 is None:
prompt_2_list = [None]
else:
prompt_2_list = [prompts_2]
for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, clip_images, mask_multiplier, prompt_2 in zip(
noisy_latents_list,
noise_list,
timesteps_list,
conditioned_prompts_list,
imgs_list,
adapter_images_list,
clip_images_list,
mask_multiplier_list,
prompt_2_list
):
with network:
# encode clip adapter here so embeds are active for tokenizer
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
with self.timer('encode_clip_vision_embeds'):
if has_clip_image:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
clip_images.detach().to(self.device_torch, dtype=dtype),
is_training=True,
has_been_preprocessed=True
)
else:
# just do a blank one
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
torch.zeros(
(noisy_latents.shape[0], 3, 512, 512),
device=self.device_torch, dtype=dtype
),
is_training=True,
has_been_preprocessed=True,
drop=True
)
# it will be injected into the tokenizer when called
self.adapter(conditional_clip_embeds)
# run the custom adapter's pre-text-encoder conditioning (trigger_pre_te)
if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
quad_count = random.randint(1, 4)
self.adapter.train()
self.adapter.trigger_pre_te(
tensors_0_1=clip_images if not is_reg else None, # on regs we send none to get random noise
is_training=True,
has_been_preprocessed=True,
quad_count=quad_count,
batch_size=noisy_latents.shape[0]
)
with self.timer('encode_prompt'):
unconditional_embeds = None
if grad_on_text_encoder:
with torch.set_grad_enabled(True):
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
conditional_embeds = self.sd.encode_prompt(
conditioned_prompts, prompt_2,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype)
if self.train_config.do_cfg:
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = True
# todo only do one and repeat it
unconditional_embeds = self.sd.encode_prompt(
self.batch_negative_prompt,
self.batch_negative_prompt,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
else:
with torch.set_grad_enabled(False):
# make sure it is in eval mode
if isinstance(self.sd.text_encoder, list):
for te in self.sd.text_encoder:
te.eval()
else:
self.sd.text_encoder.eval()
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
conditional_embeds = self.sd.encode_prompt(
conditioned_prompts, prompt_2,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype)
if self.train_config.do_cfg:
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = True
unconditional_embeds = self.sd.encode_prompt(
self.batch_negative_prompt,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
# detach the embeddings
conditional_embeds = conditional_embeds.detach()
if self.train_config.do_cfg:
unconditional_embeds = unconditional_embeds.detach()
# flush()
pred_kwargs = {}
if has_adapter_img:
if (self.adapter and isinstance(self.adapter, T2IAdapter)) or (
self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)):
with torch.set_grad_enabled(self.adapter is not None):
adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
adapter_multiplier = get_adapter_multiplier()
with self.timer('encode_adapter'):
down_block_additional_residuals = adapter(adapter_images)
if self.assistant_adapter:
# not training. detach
down_block_additional_residuals = [
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
down_block_additional_residuals
]
else:
down_block_additional_residuals = [
sample.to(dtype=dtype) * adapter_multiplier for sample in
down_block_additional_residuals
]
pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals
if self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter_embeds'):
# number of images to do if doing a quad image
quad_count = random.randint(1, 4)
image_size = self.adapter.input_size
if has_clip_image_embeds:
# todo handle reg images better than this
if is_reg:
# get unconditional image embeds from cache
embeds = [
load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
range(noisy_latents.shape[0])
]
conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
embeds,
quad_count=quad_count
)
if self.train_config.do_cfg:
embeds = [
load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
range(noisy_latents.shape[0])
]
unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
embeds,
quad_count=quad_count
)
else:
conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
batch.clip_image_embeds,
quad_count=quad_count
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
batch.clip_image_embeds_unconditional,
quad_count=quad_count
)
elif is_reg:
# we will zero it out in the img embedder
clip_images = torch.zeros(
(noisy_latents.shape[0], 3, image_size, image_size),
device=self.device_torch, dtype=dtype
).detach()
# drop will zero it out
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
clip_images,
drop=True,
is_training=True,
has_been_preprocessed=False,
quad_count=quad_count
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
torch.zeros(
(noisy_latents.shape[0], 3, image_size, image_size),
device=self.device_torch, dtype=dtype
).detach(),
is_training=True,
drop=True,
has_been_preprocessed=False,
quad_count=quad_count
)
elif has_clip_image:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
clip_images.detach().to(self.device_torch, dtype=dtype),
is_training=True,
has_been_preprocessed=True,
quad_count=quad_count,
# do cfg on clip embeds to normalize the embeddings for when doing cfg
# cfg_embed_strength=3.0 if not self.train_config.do_cfg else None
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
clip_images.detach().to(self.device_torch, dtype=dtype),
is_training=True,
drop=True,
has_been_preprocessed=True,
quad_count=quad_count
)
else:
print("No Clip Image")
print([file_item.path for file_item in batch.file_items])
raise ValueError("Could not find clip image")
if not self.adapter_config.train_image_encoder:
# we are not training the image encoder, so we need to detach the embeds
conditional_clip_embeds = conditional_clip_embeds.detach()
if self.train_config.do_cfg:
unconditional_clip_embeds = unconditional_clip_embeds.detach()
with self.timer('encode_adapter'):
self.adapter.train()
conditional_embeds = self.adapter(
conditional_embeds.detach(),
conditional_clip_embeds,
is_unconditional=False
)
if self.train_config.do_cfg:
unconditional_embeds = self.adapter(
unconditional_embeds.detach(),
unconditional_clip_embeds,
is_unconditional=True
)
else:
# wipe out the unconditional
self.adapter.last_unconditional = None
if self.adapter and isinstance(self.adapter, ReferenceAdapter):
# pass in our noise scheduler
self.adapter.noise_scheduler = self.sd.noise_scheduler
if has_clip_image or has_adapter_img:
img_to_use = clip_images if has_clip_image else adapter_images
# images are 0-1; the reference adapter expects -1 to 1
reference_images = ((img_to_use - 0.5) * 2).detach().to(self.device_torch, dtype=dtype)
self.adapter.set_reference_images(reference_images)
self.adapter.noise_scheduler = self.sd.noise_scheduler
elif is_reg:
self.adapter.set_blank_reference_images(noisy_latents.shape[0])
else:
self.adapter.set_reference_images(None)
prior_pred = None
do_reg_prior = False
# if is_reg and (self.network is not None or self.adapter is not None):
# # we are doing a reg image and we have a network or adapter
# do_reg_prior = True
do_inverted_masked_prior = False
if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
do_inverted_masked_prior = True
do_correct_pred_norm_prior = self.train_config.correct_pred_norm
do_guidance_prior = False
if batch.unconditional_latents is not None:
# for 'tnt' (this-not-that) guidance we need a prior pred to normalize against
guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type
if guidance_type == 'tnt':
do_guidance_prior = True
if ((
has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_guidance_prior or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm):
with self.timer('prior predict'):
prior_pred = self.get_prior_prediction(
noisy_latents=noisy_latents,
conditional_embeds=conditional_embeds,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
noise=noise,
batch=batch,
unconditional_embeds=unconditional_embeds,
conditioned_prompts=conditioned_prompts
)
if prior_pred is not None:
prior_pred = prior_pred.detach()
# do the custom adapter after the prior prediction
if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
quad_count = random.randint(1, 4)
self.adapter.train()
conditional_embeds = self.adapter.condition_encoded_embeds(
tensors_0_1=clip_images,
prompt_embeds=conditional_embeds,
is_training=True,
has_been_preprocessed=True,
quad_count=quad_count
)
if self.train_config.do_cfg and unconditional_embeds is not None:
unconditional_embeds = self.adapter.condition_encoded_embeds(
tensors_0_1=clip_images,
prompt_embeds=unconditional_embeds,
is_training=True,
has_been_preprocessed=True,
is_unconditional=True,
quad_count=quad_count
)
if self.adapter and isinstance(self.adapter, CustomAdapter) and batch.extra_values is not None:
self.adapter.add_extra_values(batch.extra_values.detach())
if self.train_config.do_cfg:
self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()),
is_unconditional=True)
if has_adapter_img:
if (self.adapter and isinstance(self.adapter, ControlNetModel)) or (
self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)):
if self.train_config.do_cfg:
raise ValueError("ControlNetModel is not supported with CFG")
with torch.set_grad_enabled(self.adapter is not None):
adapter: ControlNetModel = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
adapter_multiplier = get_adapter_multiplier()
with self.timer('encode_adapter'):
# add_text_embeds is pooled_prompt_embeds for sdxl
added_cond_kwargs = {}
if self.sd.is_xl:
added_cond_kwargs["text_embeds"] = conditional_embeds.pooled_embeds
added_cond_kwargs['time_ids'] = self.sd.get_time_ids_from_latents(noisy_latents)
down_block_res_samples, mid_block_res_sample = adapter(
noisy_latents,
timesteps,
encoder_hidden_states=conditional_embeds.text_embeds,
controlnet_cond=adapter_images,
conditioning_scale=1.0,
guess_mode=False,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)
pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
self.before_unet_predict()
# do a prior pred if we have an unconditional image; we will swap out the guidance later
if batch.unconditional_latents is not None or self.do_guided_loss:
# do guided loss
loss = self.get_guided_loss(
noisy_latents=noisy_latents,
conditional_embeds=conditional_embeds,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
batch=batch,
noise=noise,
unconditional_embeds=unconditional_embeds,
mask_multiplier=mask_multiplier,
prior_pred=prior_pred,
)
else:
with self.timer('predict_unet'):
if unconditional_embeds is not None:
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
noise_pred = self.predict_noise(
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
timesteps=timesteps,
conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeds=unconditional_embeds,
**pred_kwargs
)
self.after_unet_predict()
with self.timer('calculate_loss'):
noise = noise.to(self.device_torch, dtype=dtype).detach()
loss = self.calculate_loss(
noise_pred=noise_pred,
noise=noise,
noisy_latents=noisy_latents,
timesteps=timesteps,
batch=batch,
mask_multiplier=mask_multiplier,
prior_pred=prior_pred,
)
loss = loss * loss_multiplier.mean()
return loss