Added base for using guidance during training. Still not working right.

This commit is contained in:
Jaret Burkett
2023-11-05 04:03:32 -07:00
parent d35733ac06
commit 8a9e8f708f
5 changed files with 245 additions and 25 deletions

View File

@@ -1,6 +1,8 @@
from collections import OrderedDict
from typing import Union
from diffusers import T2IAdapter
from toolkit import train_tools
from toolkit.basic import value_map
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.ip_adapter import IPAdapter
@@ -30,6 +32,7 @@ class SDTrainer(BaseSDTrainProcess):
super().__init__(process_id, job, config, **kwargs)
self.assistant_adapter: Union['T2IAdapter', None]
self.do_prior_prediction = False
self.target_class = self.get_conf('target_class', '')
if self.train_config.inverted_mask_prior:
self.do_prior_prediction = True
@@ -171,6 +174,99 @@ class SDTrainer(BaseSDTrainProcess):
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
return batch
def get_guided_loss(
self,
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
match_adapter_assist: bool,
network_weight_list: list,
timesteps: torch.Tensor,
pred_kwargs: dict,
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
**kwargs
):
with torch.no_grad():
dtype = get_torch_dtype(self.train_config.dtype)
# target class is unconditional
target_class_embeds = self.sd.encode_prompt(self.target_class).detach()
if batch.unconditional_latents is not None:
# do the unconditional prediction here instead of a prior prediction
unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(batch.unconditional_latents, noise,
timesteps)
was_network_active = self.network.is_active
self.network.is_active = False
self.sd.unet.eval()
guidance_scale = 1.0
def cfg(uncon, con):
return uncon + guidance_scale * (
con - uncon
)
target_conditional = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
).detach()
target_unconditional = self.sd.predict_noise(
latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=target_class_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
).detach()
neutral_latents = (noisy_latents + unconditional_noisy_latents) / 2.0
target_noise = cfg(target_unconditional, target_conditional)
# latents = self.noise_scheduler.step(target_noise, timesteps, noisy_latents, return_dict=False)[0]
# target_pred = target_pred - noisy_latents + (unconditional_noisy_latents - noise)
# target_noise_res = noisy_latents - unconditional_noisy_latents
# target_pred = cfg(unconditional_noisy_latents, target_pred)
# target_pred = target_pred + target_noise_res
self.network.is_active = True
self.sd.unet.train()
prediction = self.sd.predict_noise(
latents=neutral_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
)
# prediction_res = target_pred - prediction
# prediction = cfg(prediction, target_pred)
loss = torch.nn.functional.mse_loss(prediction.float(), target_noise.float(), reduction="none")
loss = loss.mean([1, 2, 3])
if self.train_config.learnable_snr_gos:
# add snr_gamma
loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001:
# add snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
return loss
def get_prior_prediction(
self,
noisy_latents: torch.Tensor,
@@ -369,8 +465,6 @@ class SDTrainer(BaseSDTrainProcess):
else:
prompt_2_list = [prompts_2]
for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier, prompt_2 in zip(
noisy_latents_list,
noise_list,
@@ -386,8 +480,9 @@ class SDTrainer(BaseSDTrainProcess):
with self.timer('encode_prompt'):
if grad_on_text_encoder:
with torch.set_grad_enabled(True):
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to(
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
long_prompts=True).to(
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
self.device_torch,
dtype=dtype)
else:
@@ -398,8 +493,9 @@ class SDTrainer(BaseSDTrainProcess):
te.eval()
else:
self.sd.text_encoder.eval()
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to(
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
long_prompts=True).to(
# conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
self.device_torch,
dtype=dtype)
@@ -450,27 +546,42 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
self.before_unet_predict()
with self.timer('predict_unet'):
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs
)
self.after_unet_predict()
with self.timer('calculate_loss'):
noise = noise.to(self.device_torch, dtype=dtype).detach()
loss = self.calculate_loss(
noise_pred=noise_pred,
noise=noise,
# do a prior pred if we have an unconditional image, we will swap out the giadance later
if batch.unconditional_latents is not None:
# do guided loss
loss = self.get_guided_loss(
noisy_latents=noisy_latents,
conditional_embeds=conditional_embeds,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
batch=batch,
mask_multiplier=mask_multiplier,
prior_pred=prior_pred,
noise=noise,
)
else:
with self.timer('predict_unet'):
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs
)
self.after_unet_predict()
with self.timer('calculate_loss'):
noise = noise.to(self.device_torch, dtype=dtype).detach()
loss = self.calculate_loss(
noise_pred=noise_pred,
noise=noise,
noisy_latents=noisy_latents,
timesteps=timesteps,
batch=batch,
mask_multiplier=mask_multiplier,
prior_pred=prior_pred,
)
# check if nan
if torch.isnan(loss):
raise ValueError("loss is nan")

View File

@@ -574,7 +574,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
else:
latents = self.sd.encode_images(imgs)
batch.latents = latents
# flush() # todo check performance removing this
if batch.unconditional_tensor is not None and batch.unconditional_latents is None:
unconditional_imgs = batch.unconditional_tensor
unconditional_imgs = unconditional_imgs.to(self.device_torch, dtype=dtype)
unconditional_latents = self.sd.encode_images(unconditional_imgs)
batch.unconditional_latents = unconditional_latents
unaugmented_latents = None
if self.train_config.loss_target == 'differential_noise':
@@ -655,6 +660,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
# determine scaled noise
# todo do we need to scale this or does it always predict full intensity
# noise = noisy_latents - latents
# https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170C17-L1171C77
if self.train_config.loss_target == 'source' or self.train_config.loss_target == 'unaugmented':
sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)

View File

@@ -351,6 +351,7 @@ class DatasetConfig:
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
self.mask_path: str = kwargs.get('mask_path',
None) # focus mask (black and white. White has higher loss than black)
self.unconditional_path: str = kwargs.get('unconditional_path', None) # path where matching unconditional images are located
self.invert_mask: bool = kwargs.get('invert_mask', False) # invert mask
self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for . 0 - 1
self.poi: Union[str, None] = kwargs.get('poi',

View File

@@ -7,7 +7,8 @@ from PIL.ImageOps import exif_transpose
from toolkit import image_utils
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
UnconditionalFileItemDTOMixin
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
@@ -29,6 +30,7 @@ class FileItemDTO(
ControlFileItemDTOMixin,
MaskFileItemDTOMixin,
AugmentationFileItemDTOMixin,
UnconditionalFileItemDTOMixin,
PoiFileItemDTOMixin,
ArgBreakMixin,
):
@@ -70,6 +72,7 @@ class FileItemDTO(
self.cleanup_latent()
self.cleanup_control()
self.cleanup_mask()
self.cleanup_unconditional()
class DataLoaderBatchDTO:
@@ -82,6 +85,8 @@ class DataLoaderBatchDTO:
self.control_tensor: Union[torch.Tensor, None] = None
self.mask_tensor: Union[torch.Tensor, None] = None
self.unaugmented_tensor: Union[torch.Tensor, None] = None
self.unconditional_tensor: Union[torch.Tensor, None] = None
self.unconditional_latents: Union[torch.Tensor, None] = None
self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
if not is_latents_cached:
# only return a tensor if latents are not cached
@@ -138,6 +143,22 @@ class DataLoaderBatchDTO:
else:
unaugmented_tensor.append(x.unaugmented_tensor)
self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor])
# add unconditional tensors
if any([x.unconditional_tensor is not None for x in self.file_items]):
# find one to use as a base
base_unconditional_tensor = None
for x in self.file_items:
if x.unaugmented_tensor is not None:
base_unconditional_tensor = x.unconditional_tensor
break
unconditional_tensor = []
for x in self.file_items:
if x.unconditional_tensor is None:
unconditional_tensor.append(torch.zeros_like(base_unconditional_tensor))
else:
unconditional_tensor.append(x.unconditional_tensor)
self.unconditional_tensor = torch.cat([x.unsqueeze(0) for x in unconditional_tensor])
except Exception as e:
print(e)
raise e

View File

@@ -351,6 +351,8 @@ class ImageProcessingDTOMixin:
self.load_control_image()
if self.has_mask_image:
self.load_mask_image()
if self.has_unconditional:
self.load_unconditional_image()
return
try:
img = Image.open(self.path)
@@ -442,6 +444,8 @@ class ImageProcessingDTOMixin:
self.load_control_image()
if self.has_mask_image:
self.load_mask_image()
if self.has_unconditional:
self.load_unconditional_image()
class ControlFileItemDTOMixin:
@@ -661,6 +665,80 @@ class MaskFileItemDTOMixin:
self.mask_tensor = None
class UnconditionalFileItemDTOMixin:
def __init__(self: 'FileItemDTO', *args, **kwargs):
if hasattr(super(), '__init__'):
super().__init__(*args, **kwargs)
self.has_unconditional = False
self.unconditional_path: Union[str, None] = None
self.unconditional_tensor: Union[torch.Tensor, None] = None
self.unconditional_latent: Union[torch.Tensor, None] = None
self.unconditional_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
if dataset_config.unconditional_path is not None:
# we are using control images
img_path = kwargs.get('path', None)
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
for ext in img_ext_list:
if os.path.exists(os.path.join(dataset_config.unconditional_path, file_name_no_ext + ext)):
self.unconditional_path = os.path.join(dataset_config.unconditional_path, file_name_no_ext + ext)
self.has_unconditional = True
break
def load_unconditional_image(self: 'FileItemDTO'):
try:
img = Image.open(self.unconditional_path)
img = exif_transpose(img)
except Exception as e:
print(f"Error: {e}")
print(f"Error loading image: {self.mask_path}")
img = img.convert('RGB')
w, h = img.size
if w > h and self.scale_to_width < self.scale_to_height:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
elif h > w and self.scale_to_height < self.scale_to_width:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
if self.flip_x:
# do a flip
img.transpose(Image.FLIP_LEFT_RIGHT)
if self.flip_y:
# do a flip
img.transpose(Image.FLIP_TOP_BOTTOM)
if self.dataset_config.buckets:
# scale and crop based on file item
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
# crop
img = img.crop((
self.crop_x,
self.crop_y,
self.crop_x + self.crop_width,
self.crop_y + self.crop_height
))
else:
raise Exception("Unconditional images are not supported for non-bucket datasets")
self.unconditional_tensor = self.unconditional_transforms(img)
def cleanup_unconditional(self: 'FileItemDTO'):
self.unconditional_tensor = None
self.unconditional_latent = None
class PoiFileItemDTOMixin:
# Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject
# items in the poi will always be inside the image when random cropping