mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Bug fixes. Added ability to use L1 loss. Various other tests and improvements
This commit is contained in:
@@ -6,10 +6,14 @@ import numpy as np
|
||||
from diffusers import T2IAdapter, AutoencoderTiny
|
||||
|
||||
import torch.functional as F
|
||||
from safetensors.torch import load_file
|
||||
from torch.utils.data import DataLoader, ConcatDataset
|
||||
|
||||
from toolkit import train_tools
|
||||
from toolkit.basic import value_map, adain, get_mean_std
|
||||
from toolkit.clip_vision_adapter import ClipVisionAdapter
|
||||
from toolkit.config_modules import GuidanceConfig
|
||||
from toolkit.data_loader import get_dataloader_datasets
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
|
||||
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
|
||||
from toolkit.image_utils import show_tensors, show_latents
|
||||
@@ -46,6 +50,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.do_guided_loss = False
|
||||
self.taesd: Optional[AutoencoderTiny] = None
|
||||
|
||||
self._clip_image_embeds_unconditional: Union[List[str], None] = None
|
||||
|
||||
def before_model_load(self):
    """Intentionally empty hook: performs no work and returns None.

    NOTE(review): presumably invoked by the base training process before
    the model is loaded -- confirm against the caller.
    """
    return None
|
||||
|
||||
@@ -86,6 +92,22 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.adapter is not None:
|
||||
self.adapter.to(self.device_torch)
|
||||
|
||||
# check if we have regs and using adapter and caching clip embeddings
|
||||
has_reg = self.datasets_reg is not None and len(self.datasets_reg) > 0
|
||||
is_caching_clip_embeddings = any([self.datasets[i].cache_clip_vision_to_disk for i in range(len(self.datasets))])
|
||||
|
||||
if has_reg and is_caching_clip_embeddings:
|
||||
# we need a list of unconditional clip image embeds from other datasets to handle regs
|
||||
unconditional_clip_image_embeds = []
|
||||
datasets = get_dataloader_datasets(self.data_loader)
|
||||
for i in range(len(datasets)):
|
||||
unconditional_clip_image_embeds += datasets[i].clip_vision_unconditional_cache
|
||||
|
||||
if len(unconditional_clip_image_embeds) == 0:
|
||||
raise ValueError("No unconditional clip image embeds found. This should not happen")
|
||||
|
||||
self._clip_image_embeds_unconditional = unconditional_clip_image_embeds
|
||||
|
||||
def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):
|
||||
# to process turbo learning, we make one big step from our current timestep to the end
|
||||
# we then denoise the prediction on that remaining step and target our loss to our target latents
|
||||
@@ -190,6 +212,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
**kwargs
|
||||
):
|
||||
loss_target = self.train_config.loss_target
|
||||
is_reg = any(batch.get_is_reg_list())
|
||||
|
||||
prior_mask_multiplier = None
|
||||
target_mask_multiplier = None
|
||||
@@ -202,6 +225,40 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
|
||||
noise_pred = noise_pred * (noise_norm / noise_pred_norm)
|
||||
|
||||
if self.train_config.correct_pred_norm and not is_reg:
|
||||
with torch.no_grad():
|
||||
|
||||
# adjust the noise target in the opposite direction of the noise pred mean and std offset
|
||||
# this will apply additional force on the model to correct itself to match the norm of the noise
|
||||
noise_pred_mean, noise_pred_std = get_mean_std(noise_pred)
|
||||
noise_mean, noise_std = get_mean_std(noise)
|
||||
|
||||
# apply the inverse offset of the mean and std to the noise
|
||||
noise_additional_mean = noise_mean - noise_pred_mean
|
||||
noise_additional_std = noise_std - noise_pred_std
|
||||
|
||||
# adjust for multiplier
|
||||
noise_additional_mean = noise_additional_mean * self.train_config.correct_pred_norm_multiplier
|
||||
noise_additional_std = noise_additional_std * self.train_config.correct_pred_norm_multiplier
|
||||
|
||||
noise_target_std = noise_std + noise_additional_std
|
||||
noise_target_mean = noise_mean + noise_additional_mean
|
||||
|
||||
|
||||
noise_pred_target_std = noise_pred_std - noise_additional_std
|
||||
noise_pred_target_mean = noise_pred_mean - noise_additional_mean
|
||||
noise_pred_target_std = noise_pred_target_std.detach()
|
||||
noise_pred_target_mean = noise_pred_target_mean.detach()
|
||||
|
||||
# match the noise to the target
|
||||
noise = (noise - noise_mean) / noise_std
|
||||
noise = noise * noise_target_std + noise_target_mean
|
||||
noise = noise.detach()
|
||||
|
||||
# match the noise pred to the target
|
||||
# noise_pred = (noise_pred - noise_pred_mean) / noise_pred_std
|
||||
# noise_pred = noise_pred * noise_pred_target_std + noise_pred_target_mean
|
||||
|
||||
if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
|
||||
assert not self.train_config.train_turbo
|
||||
# we need to make the noise prediction be a masked blending of noise and prior_pred
|
||||
@@ -227,7 +284,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
target = prior_pred
|
||||
elif self.sd.prediction_type == 'v_prediction':
|
||||
# v-parameterization training
|
||||
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
@@ -270,7 +327,10 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
|
||||
loss = loss_per_element
|
||||
else:
|
||||
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
|
||||
if self.train_config.loss_type == "mae":
|
||||
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
|
||||
else:
|
||||
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
|
||||
|
||||
# multiply by our mask
|
||||
loss = loss * mask_multiplier
|
||||
@@ -278,12 +338,11 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
prior_loss = None
|
||||
if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
|
||||
assert not self.train_config.train_turbo
|
||||
# apply a loss to unmasked areas of the prior for unmasked regularization
|
||||
prior_loss = torch.nn.functional.mse_loss(
|
||||
prior_pred.float(),
|
||||
pred.float(),
|
||||
reduction="none"
|
||||
)
|
||||
if self.train_config.loss_type == "mae":
|
||||
prior_loss = torch.nn.functional.l1_loss(pred.float(), prior_pred.float(), reduction="none")
|
||||
else:
|
||||
prior_loss = torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none")
|
||||
|
||||
prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
|
||||
if torch.isnan(prior_loss).any():
|
||||
print("Prior loss is nan")
|
||||
@@ -717,6 +776,13 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
has_adapter_img = batch.control_tensor is not None
|
||||
has_clip_image = batch.clip_image_tensor is not None
|
||||
has_clip_image_embeds = batch.clip_image_embeds is not None
|
||||
# force it to be true if doing regs as we handle those differently
|
||||
if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]):
|
||||
has_clip_image = True
|
||||
if self._clip_image_embeds_unconditional is not None:
|
||||
has_clip_image_embeds = True # we are caching embeds, handle that differently
|
||||
has_clip_image = False
|
||||
|
||||
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
|
||||
raise ValueError(
|
||||
@@ -996,7 +1062,39 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# number of images to do if doing a quad image
|
||||
quad_count = random.randint(1, 4)
|
||||
image_size = self.adapter.input_size
|
||||
if is_reg:
|
||||
if has_clip_image_embeds:
|
||||
# todo handle reg images better than this
|
||||
if is_reg:
|
||||
# get unconditional image embeds from cache
|
||||
embeds = [
|
||||
load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
|
||||
range(noisy_latents.shape[0])
|
||||
]
|
||||
conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
|
||||
embeds,
|
||||
quad_count=quad_count
|
||||
)
|
||||
|
||||
if self.train_config.do_cfg:
|
||||
embeds = [
|
||||
load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in range(noisy_latents.shape[0])
|
||||
]
|
||||
unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
|
||||
embeds,
|
||||
quad_count=quad_count
|
||||
)
|
||||
|
||||
else:
|
||||
conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
|
||||
batch.clip_image_embeds,
|
||||
quad_count=quad_count
|
||||
)
|
||||
if self.train_config.do_cfg:
|
||||
unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
|
||||
batch.clip_image_embeds_unconditional,
|
||||
quad_count=quad_count
|
||||
)
|
||||
elif is_reg:
|
||||
# we will zero it out in the img embedder
|
||||
clip_images = torch.zeros(
|
||||
(noisy_latents.shape[0], 3, image_size, image_size),
|
||||
@@ -1071,9 +1169,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
prior_pred = None
|
||||
|
||||
do_reg_prior = False
|
||||
# if is_reg and (self.network is not None or self.adapter is not None):
|
||||
# # we are doing a reg image and we have a network or adapter
|
||||
# do_reg_prior = True
|
||||
if is_reg and (self.network is not None or self.adapter is not None):
|
||||
# we are doing a reg image and we have a network or adapter
|
||||
do_reg_prior = True
|
||||
|
||||
do_inverted_masked_prior = False
|
||||
if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
|
||||
@@ -1096,12 +1194,14 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
# do the custom adapter after the prior prediction
|
||||
if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
|
||||
quad_count = random.randint(1, 4)
|
||||
self.adapter.train()
|
||||
conditional_embeds = self.adapter.condition_encoded_embeds(
|
||||
tensors_0_1=clip_images,
|
||||
prompt_embeds=conditional_embeds,
|
||||
is_training=True,
|
||||
has_been_preprocessed=True,
|
||||
quad_count=quad_count
|
||||
)
|
||||
if self.train_config.do_cfg and unconditional_embeds is not None:
|
||||
unconditional_embeds = self.adapter.condition_encoded_embeds(
|
||||
@@ -1109,7 +1209,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
prompt_embeds=unconditional_embeds,
|
||||
is_training=True,
|
||||
has_been_preprocessed=True,
|
||||
is_unconditional=True
|
||||
is_unconditional=True,
|
||||
quad_count=quad_count
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user