Bug fixes. Added the ability to use L1 loss. Various other tests and improvements

This commit is contained in:
Jaret Burkett
2024-01-31 06:30:54 -07:00
parent 92b9c71d44
commit 1ae1017748
9 changed files with 474 additions and 23 deletions


@@ -6,10 +6,14 @@ import numpy as np
from diffusers import T2IAdapter, AutoencoderTiny
import torch.nn.functional as F
from safetensors.torch import load_file
from torch.utils.data import DataLoader, ConcatDataset
from toolkit import train_tools
from toolkit.basic import value_map, adain, get_mean_std
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.config_modules import GuidanceConfig
from toolkit.data_loader import get_dataloader_datasets
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
from toolkit.image_utils import show_tensors, show_latents
@@ -46,6 +50,8 @@ class SDTrainer(BaseSDTrainProcess):
self.do_guided_loss = False
self.taesd: Optional[AutoencoderTiny] = None
self._clip_image_embeds_unconditional: Union[List[str], None] = None
def before_model_load(self):
pass
@@ -86,6 +92,22 @@ class SDTrainer(BaseSDTrainProcess):
if self.adapter is not None:
self.adapter.to(self.device_torch)
# check if we have regs and using adapter and caching clip embeddings
has_reg = self.datasets_reg is not None and len(self.datasets_reg) > 0
is_caching_clip_embeddings = any([self.datasets[i].cache_clip_vision_to_disk for i in range(len(self.datasets))])
if has_reg and is_caching_clip_embeddings:
# we need a list of unconditional clip image embeds from other datasets to handle regs
unconditional_clip_image_embeds = []
datasets = get_dataloader_datasets(self.data_loader)
for i in range(len(datasets)):
unconditional_clip_image_embeds += datasets[i].clip_vision_unconditional_cache
if len(unconditional_clip_image_embeds) == 0:
raise ValueError("No unconditional clip image embeds found. This should not happen")
self._clip_image_embeds_unconditional = unconditional_clip_image_embeds
def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):
# to process turbo learning, we make one big step from our current timestep to the end
# we then denoise the prediction over that remaining step and compute our loss against the target latents
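
A minimal sketch of that one-big-step idea, assuming an epsilon-prediction model and a scheduler that exposes alphas_cumprod; the trainer's actual implementation may differ:

    import torch

    def one_step_denoise_sketch(noisy_latents, noise_pred, timesteps, alphas_cumprod):
        # cumulative alpha for each sample's current timestep, shaped for broadcasting
        a_bar = alphas_cumprod[timesteps].view(-1, 1, 1, 1)
        # jump from timestep t straight to the predicted clean latents (t -> 0)
        pred_latents = (noisy_latents - torch.sqrt(1.0 - a_bar) * noise_pred) / torch.sqrt(a_bar)
        # the loss is then taken against the target latents rather than the noise
        return pred_latents
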
@@ -190,6 +212,7 @@ class SDTrainer(BaseSDTrainProcess):
**kwargs
):
loss_target = self.train_config.loss_target
is_reg = any(batch.get_is_reg_list())
prior_mask_multiplier = None
target_mask_multiplier = None
@@ -202,6 +225,40 @@ class SDTrainer(BaseSDTrainProcess):
noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
noise_pred = noise_pred * (noise_norm / noise_pred_norm)
if self.train_config.correct_pred_norm and not is_reg:
with torch.no_grad():
# adjust the noise target in the opposite direction of the noise pred mean and std offset
# this will apply additional force on the model to correct itself to match the norm of the noise
noise_pred_mean, noise_pred_std = get_mean_std(noise_pred)
noise_mean, noise_std = get_mean_std(noise)
# apply the inverse offset of the mean and std to the noise
noise_additional_mean = noise_mean - noise_pred_mean
noise_additional_std = noise_std - noise_pred_std
# adjust for multiplier
noise_additional_mean = noise_additional_mean * self.train_config.correct_pred_norm_multiplier
noise_additional_std = noise_additional_std * self.train_config.correct_pred_norm_multiplier
noise_target_std = noise_std + noise_additional_std
noise_target_mean = noise_mean + noise_additional_mean
noise_pred_target_std = noise_pred_std - noise_additional_std
noise_pred_target_mean = noise_pred_mean - noise_additional_mean
noise_pred_target_std = noise_pred_target_std.detach()
noise_pred_target_mean = noise_pred_target_mean.detach()
# match the noise to the target
noise = (noise - noise_mean) / noise_std
noise = noise * noise_target_std + noise_target_mean
noise = noise.detach()
# match the noise pred to the target
# noise_pred = (noise_pred - noise_pred_mean) / noise_pred_std
# noise_pred = noise_pred * noise_pred_target_std + noise_pred_target_mean
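
For reference, a minimal sketch of the per-sample statistics this block relies on; the real get_mean_std lives in toolkit.basic and may differ in detail (this assumes per-sample mean/std over the channel and spatial dimensions, kept for broadcasting):

    import torch

    def get_mean_std_sketch(t: torch.Tensor, eps: float = 1e-6):
        # per-sample statistics over (C, H, W); keepdim so they broadcast against t
        mean = t.mean(dim=(1, 2, 3), keepdim=True)
        std = t.std(dim=(1, 2, 3), keepdim=True) + eps
        return mean, std

With those statistics, the noise target is renormalized toward a mean/std shifted away from the prediction's offset, which pushes the model to correct its own output norm.
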
if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
assert not self.train_config.train_turbo
# we need to make the noise prediction be a masked blending of noise and prior_pred
@@ -227,7 +284,7 @@ class SDTrainer(BaseSDTrainProcess):
target = prior_pred
elif self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
else:
target = noise
@@ -270,7 +327,10 @@ class SDTrainer(BaseSDTrainProcess):
loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
loss = loss_per_element
else:
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
if self.train_config.loss_type == "mae":
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
else:
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
# multiply by our mask
loss = loss * mask_multiplier
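
The new loss_type switch, isolated as a standalone sketch (the names mirror the diff; the surrounding config plumbing is assumed):

    import torch
    import torch.nn.functional as F

    def element_loss_sketch(pred: torch.Tensor, target: torch.Tensor, loss_type: str = "mse") -> torch.Tensor:
        # "mae" selects L1 loss; anything else falls back to MSE, both unreduced
        if loss_type == "mae":
            return F.l1_loss(pred.float(), target.float(), reduction="none")
        return F.mse_loss(pred.float(), target.float(), reduction="none")

L1 penalizes large residuals less aggressively than MSE, which can soften the influence of outlier pixels or latents on the gradient.
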
@@ -278,12 +338,11 @@ class SDTrainer(BaseSDTrainProcess):
prior_loss = None
if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
assert not self.train_config.train_turbo
# apply a loss to the unmasked areas of the prior for unmasked regularization
prior_loss = torch.nn.functional.mse_loss(
prior_pred.float(),
pred.float(),
reduction="none"
)
if self.train_config.loss_type == "mae":
prior_loss = torch.nn.functional.l1_loss(pred.float(), prior_pred.float(), reduction="none")
else:
prior_loss = torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none")
prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
if torch.isnan(prior_loss).any():
print("Prior loss is nan")
@@ -717,6 +776,13 @@ class SDTrainer(BaseSDTrainProcess):
has_adapter_img = batch.control_tensor is not None
has_clip_image = batch.clip_image_tensor is not None
has_clip_image_embeds = batch.clip_image_embeds is not None
# force it to be true if doing regs as we handle those differently
if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]):
has_clip_image = True
if self._clip_image_embeds_unconditional is not None:
has_clip_image_embeds = True # we are caching embeds, handle that differently
has_clip_image = False
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
raise ValueError(
@@ -996,7 +1062,39 @@ class SDTrainer(BaseSDTrainProcess):
# number of images to do if doing a quad image
quad_count = random.randint(1, 4)
image_size = self.adapter.input_size
if is_reg:
if has_clip_image_embeds:
# todo handle reg images better than this
if is_reg:
# get unconditional image embeds from the cache
embeds = [
load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
range(noisy_latents.shape[0])
]
conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
embeds,
quad_count=quad_count
)
if self.train_config.do_cfg:
embeds = [
load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in range(noisy_latents.shape[0])
]
unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
embeds,
quad_count=quad_count
)
else:
conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
batch.clip_image_embeds,
quad_count=quad_count
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
batch.clip_image_embeds_unconditional,
quad_count=quad_count
)
elif is_reg:
# we will zero it out in the img embedder
clip_images = torch.zeros(
(noisy_latents.shape[0], 3, image_size, image_size),
@@ -1071,9 +1169,9 @@ class SDTrainer(BaseSDTrainProcess):
prior_pred = None
do_reg_prior = False
# if is_reg and (self.network is not None or self.adapter is not None):
# # we are doing a reg image and we have a network or adapter
# do_reg_prior = True
if is_reg and (self.network is not None or self.adapter is not None):
# we are doing a reg image and we have a network or adapter
do_reg_prior = True
do_inverted_masked_prior = False
if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
@@ -1096,12 +1194,14 @@ class SDTrainer(BaseSDTrainProcess):
# do the custom adapter after the prior prediction
if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
quad_count = random.randint(1, 4)
self.adapter.train()
conditional_embeds = self.adapter.condition_encoded_embeds(
tensors_0_1=clip_images,
prompt_embeds=conditional_embeds,
is_training=True,
has_been_preprocessed=True,
quad_count=quad_count
)
if self.train_config.do_cfg and unconditional_embeds is not None:
unconditional_embeds = self.adapter.condition_encoded_embeds(
@@ -1109,7 +1209,8 @@ class SDTrainer(BaseSDTrainProcess):
prompt_embeds=unconditional_embeds,
is_training=True,
has_been_preprocessed=True,
is_unconditional=True
is_unconditional=True,
quad_count=quad_count
)