Small tweaks and fixes for specialized ip adapter training

This commit is contained in:
Jaret Burkett
2024-03-26 11:35:26 -06:00
parent 9c1cc9641e
commit 427847ac4c
6 changed files with 117 additions and 11 deletions

View File

@@ -1211,7 +1211,7 @@ class SDTrainer(BaseSDTrainProcess):
clip_images,
drop=True,
is_training=True,
has_been_preprocessed=True,
has_been_preprocessed=False,
quad_count=quad_count
)
if self.train_config.do_cfg:
@@ -1222,7 +1222,7 @@ class SDTrainer(BaseSDTrainProcess):
).detach(),
is_training=True,
drop=True,
has_been_preprocessed=True,
has_been_preprocessed=False,
quad_count=quad_count
)
elif has_clip_image:
@@ -1230,14 +1230,14 @@ class SDTrainer(BaseSDTrainProcess):
clip_images.detach().to(self.device_torch, dtype=dtype),
is_training=True,
has_been_preprocessed=True,
quad_count=quad_count
quad_count=quad_count,
# do cfg on clip embeds to normalize the embeddings for when doing cfg
# cfg_embed_strength=3.0 if not self.train_config.do_cfg else None
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
torch.zeros(
(noisy_latents.shape[0], 3, image_size, image_size),
device=self.device_torch, dtype=dtype
).detach(),
clip_images.detach().to(self.device_torch, dtype=dtype),
is_training=True,
drop=True,
has_been_preprocessed=True,

View File

@@ -352,6 +352,9 @@ class ModelConfig:
if self.is_vega:
self.is_xl = True
# for text encoder quant. Only works with pixart currently
self.text_encoder_bits = kwargs.get('text_encoder_bits', 8) # 16, 8, 4
class ReferenceDatasetConfig:
def __init__(self, **kwargs):

View File

@@ -728,6 +728,12 @@ class ClipImageFileItemDTOMixin:
# do a flip
img = img.transpose(Image.FLIP_TOP_BOTTOM)
# image must be square. If it is not, we will resize/squish it so it is, that way we don't crop out data
if img.width != img.height:
# resize to the smallest dimension
min_size = min(img.width, img.height)
img = img.resize((min_size, min_size), Image.BICUBIC)
if self.has_clip_augmentations:
self.clip_image_tensor = self.augment_clip_image(img, transform=None)
else:

View File

@@ -14,6 +14,7 @@ from toolkit.models.zipper_resampler import ZipperResampler
from toolkit.paths import REPOS_ROOT
from toolkit.saving import load_ip_adapter_model
from toolkit.train_tools import get_torch_dtype
from toolkit.util.inverse_cfg import inverse_classifier_guidance
sys.path.append(REPOS_ROOT)
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
@@ -376,8 +377,10 @@ class IPAdapter(torch.nn.Module):
output_dim = sd.unet.config['cross_attention_dim']
if is_pixart:
heads = 20
dim = 4096
# heads = 20
heads = 12
# dim = 4096
dim = 1280
output_dim = 4096
if self.config.image_encoder_arch.startswith('convnext'):
@@ -628,6 +631,23 @@ class IPAdapter(torch.nn.Module):
clip_image_embeds = clip_image_embeds.to(device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
return clip_image_embeds
def get_empty_clip_image(self, batch_size: int) -> torch.Tensor:
    """Build a batch of "empty" (pure-noise) CLIP input images.

    Used as the unconditional image input when doing CFG on image embeds
    (see get_clip_image_embeds_from_tensors): each sample is uniform noise
    in [0, 1), attenuated by a random per-sample strength, quantized to
    8-bit levels, then normalized with the CLIP processor's mean/std.

    :param batch_size: number of noise images to generate
    :return: detached tensor of shape (batch_size, 3, input_size, input_size)
    """
    with torch.no_grad():
        tensors_0_1 = torch.rand([batch_size, 3, self.input_size, self.input_size], device=self.device)
        # random per-sample attenuation so the "empty" image varies in strength
        noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
                                 dtype=get_torch_dtype(self.sd_ref().dtype))
        tensors_0_1 = tensors_0_1 * noise_scale
        # tensors_0_1 = tensors_0_1 * 0
        mean = torch.tensor(self.clip_image_processor.image_mean).to(
            self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
        ).detach()
        std = torch.tensor(self.clip_image_processor.image_std).to(
            self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
        ).detach()
        # quantize to 8-bit levels to mimic a real uint8 image round-trip
        tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
        # normalize exactly as the CLIP image processor would
        clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
        return clip_image.detach()
def get_clip_image_embeds_from_tensors(
self,
tensors_0_1: torch.Tensor,
@@ -635,6 +655,7 @@ class IPAdapter(torch.nn.Module):
is_training=False,
has_been_preprocessed=False,
quad_count=4,
cfg_embed_strength=None, # perform CFG on embeds with unconditional as negative
) -> torch.Tensor:
if self.sd_ref().unet.device != self.device:
self.to(self.sd_ref().unet.device)
@@ -642,6 +663,7 @@ class IPAdapter(torch.nn.Module):
self.to(self.sd_ref().unet.device)
if not self.config.train:
is_training = False
uncond_clip = None
with torch.no_grad():
# on training the clip image is created in the dataloader
if not has_been_preprocessed:
@@ -749,6 +771,48 @@ class IPAdapter(torch.nn.Module):
# rearrange to (batch, tokens, size)
clip_image_embeds = clip_image_embeds.permute(0, 2, 1)
# apply unconditional if doing cfg on embeds
with torch.no_grad():
if cfg_embed_strength is not None:
uncond_clip = self.get_empty_clip_image(tensors_0_1.shape[0]).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
if self.config.quad_image:
# split the 4x4 grid and stack on batch
ci1, ci2 = uncond_clip.chunk(2, dim=2)
ci1, ci3 = ci1.chunk(2, dim=3)
ci2, ci4 = ci2.chunk(2, dim=3)
to_cat = []
for i, ci in enumerate([ci1, ci2, ci3, ci4]):
if i < quad_count:
to_cat.append(ci)
else:
break
uncond_clip = torch.cat(to_cat, dim=0).detach()
uncond_clip_output = self.image_encoder(
uncond_clip, output_hidden_states=True
)
if self.config.clip_layer == 'penultimate_hidden_states':
uncond_clip_output_embeds = uncond_clip_output.hidden_states[-2]
elif self.config.clip_layer == 'last_hidden_state':
uncond_clip_output_embeds = uncond_clip_output.hidden_states[-1]
else:
uncond_clip_output_embeds = uncond_clip_output.image_embeds
if self.config.adapter_type == "clip_face":
l2_norm = torch.norm(uncond_clip_output_embeds, p=2)
uncond_clip_output_embeds = uncond_clip_output_embeds / l2_norm
uncond_clip_output_embeds = uncond_clip_output_embeds.detach()
# apply inverse cfg
clip_image_embeds = inverse_classifier_guidance(
clip_image_embeds,
uncond_clip_output_embeds,
cfg_embed_strength
)
if self.config.quad_image:
# get the outputs of the quat
chunks = clip_image_embeds.chunk(quad_count, dim=0)

View File

@@ -48,6 +48,7 @@ from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler
from transformers import T5EncoderModel
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
from toolkit.util.inverse_cfg import inverse_classifier_guidance
# tell it to shut up
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
@@ -234,13 +235,20 @@ class StableDiffusion:
flush()
print("Injecting alt weights")
elif self.model_config.is_pixart:
te_kwargs = {}
# handle quantization of TE
if self.model_config.text_encoder_bits == 8:
te_kwargs['load_in_8bit'] = True
te_kwargs['device_map'] = "auto"
elif self.model_config.text_encoder_bits == 4:
te_kwargs['load_in_4bit'] = True
te_kwargs['device_map'] = "auto"
# load the TE with the configured quantization (16/8/4 bit via te_kwargs)
text_encoder = T5EncoderModel.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
subfolder="text_encoder",
load_in_8bit=True,
device_map="auto",
torch_dtype=self.torch_dtype,
**te_kwargs
)
# load the transformer

View File

@@ -0,0 +1,25 @@
import torch
def inverse_classifier_guidance(
        noise_pred_cond: torch.Tensor,
        noise_pred_uncond: torch.Tensor,
        guidance_scale: float
):
    """Invert classifier-free guidance for the conditional prediction.

    Standard CFG combines predictions as

        pred = uncond + scale * (cond - uncond)

    This returns an adjusted conditional prediction ``adjusted`` such that
    applying CFG to ``(adjusted, uncond)`` with ``guidance_scale``
    reproduces the original ``noise_pred_cond`` exactly:

        adjusted = uncond + (cond - uncond) / scale

    :param noise_pred_cond: conditional prediction (or embedding) to adjust
    :param noise_pred_uncond: unconditional prediction used as the CFG negative
    :param guidance_scale: CFG strength the adjusted value will be used with
    :return: adjusted conditional tensor
    """
    if guidance_scale == 1:
        # scale == 1 makes CFG the identity on the conditional branch,
        # so no adjustment is needed (and none is possible: the uncond
        # term cancels entirely).
        return noise_pred_cond
    return noise_pred_uncond + (noise_pred_cond - noise_pred_uncond) / guidance_scale