mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-26 15:23:57 +00:00
Small tweaks and fixes for specialized ip adapter training
This commit is contained in:
@@ -1211,7 +1211,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
clip_images,
|
||||
drop=True,
|
||||
is_training=True,
|
||||
has_been_preprocessed=True,
|
||||
has_been_preprocessed=False,
|
||||
quad_count=quad_count
|
||||
)
|
||||
if self.train_config.do_cfg:
|
||||
@@ -1222,7 +1222,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
).detach(),
|
||||
is_training=True,
|
||||
drop=True,
|
||||
has_been_preprocessed=True,
|
||||
has_been_preprocessed=False,
|
||||
quad_count=quad_count
|
||||
)
|
||||
elif has_clip_image:
|
||||
@@ -1230,14 +1230,14 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
clip_images.detach().to(self.device_torch, dtype=dtype),
|
||||
is_training=True,
|
||||
has_been_preprocessed=True,
|
||||
quad_count=quad_count
|
||||
quad_count=quad_count,
|
||||
# do cfg on clip embeds to normalize the embeddings for when doing cfg
|
||||
# cfg_embed_strength=3.0 if not self.train_config.do_cfg else None
|
||||
# cfg_embed_strength=3.0 if not self.train_config.do_cfg else None
|
||||
)
|
||||
if self.train_config.do_cfg:
|
||||
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
|
||||
torch.zeros(
|
||||
(noisy_latents.shape[0], 3, image_size, image_size),
|
||||
device=self.device_torch, dtype=dtype
|
||||
).detach(),
|
||||
clip_images.detach().to(self.device_torch, dtype=dtype),
|
||||
is_training=True,
|
||||
drop=True,
|
||||
has_been_preprocessed=True,
|
||||
|
||||
@@ -352,6 +352,9 @@ class ModelConfig:
|
||||
if self.is_vega:
|
||||
self.is_xl = True
|
||||
|
||||
# for text encoder quant. Only works with pixart currently
|
||||
self.text_encoder_bits = kwargs.get('text_encoder_bits', 8) # 16, 8, 4
|
||||
|
||||
|
||||
class ReferenceDatasetConfig:
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
@@ -728,6 +728,12 @@ class ClipImageFileItemDTOMixin:
|
||||
# do a flip
|
||||
img = img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||
|
||||
# image must be square. If it is not, we will resize/squish it so it is, that way we don't crop out data
|
||||
if img.width != img.height:
|
||||
# resize to the smallest dimension
|
||||
min_size = min(img.width, img.height)
|
||||
img = img.resize((min_size, min_size), Image.BICUBIC)
|
||||
|
||||
if self.has_clip_augmentations:
|
||||
self.clip_image_tensor = self.augment_clip_image(img, transform=None)
|
||||
else:
|
||||
|
||||
@@ -14,6 +14,7 @@ from toolkit.models.zipper_resampler import ZipperResampler
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.saving import load_ip_adapter_model
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
from toolkit.util.inverse_cfg import inverse_classifier_guidance
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
|
||||
@@ -376,8 +377,10 @@ class IPAdapter(torch.nn.Module):
|
||||
output_dim = sd.unet.config['cross_attention_dim']
|
||||
|
||||
if is_pixart:
|
||||
heads = 20
|
||||
dim = 4096
|
||||
# heads = 20
|
||||
heads = 12
|
||||
# dim = 4096
|
||||
dim = 1280
|
||||
output_dim = 4096
|
||||
|
||||
if self.config.image_encoder_arch.startswith('convnext'):
|
||||
@@ -628,6 +631,23 @@ class IPAdapter(torch.nn.Module):
|
||||
clip_image_embeds = clip_image_embeds.to(device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
|
||||
return clip_image_embeds
|
||||
|
||||
def get_empty_clip_image(self, batch_size: int) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
tensors_0_1 = torch.rand([batch_size, 3, self.input_size, self.input_size], device=self.device)
|
||||
noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
|
||||
dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
tensors_0_1 = tensors_0_1 * noise_scale
|
||||
# tensors_0_1 = tensors_0_1 * 0
|
||||
mean = torch.tensor(self.clip_image_processor.image_mean).to(
|
||||
self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
|
||||
).detach()
|
||||
std = torch.tensor(self.clip_image_processor.image_std).to(
|
||||
self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
|
||||
).detach()
|
||||
tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
|
||||
clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
|
||||
return clip_image.detach()
|
||||
|
||||
def get_clip_image_embeds_from_tensors(
|
||||
self,
|
||||
tensors_0_1: torch.Tensor,
|
||||
@@ -635,6 +655,7 @@ class IPAdapter(torch.nn.Module):
|
||||
is_training=False,
|
||||
has_been_preprocessed=False,
|
||||
quad_count=4,
|
||||
cfg_embed_strength=None, # perform CFG on embeds with unconditional as negative
|
||||
) -> torch.Tensor:
|
||||
if self.sd_ref().unet.device != self.device:
|
||||
self.to(self.sd_ref().unet.device)
|
||||
@@ -642,6 +663,7 @@ class IPAdapter(torch.nn.Module):
|
||||
self.to(self.sd_ref().unet.device)
|
||||
if not self.config.train:
|
||||
is_training = False
|
||||
uncond_clip = None
|
||||
with torch.no_grad():
|
||||
# on training the clip image is created in the dataloader
|
||||
if not has_been_preprocessed:
|
||||
@@ -749,6 +771,48 @@ class IPAdapter(torch.nn.Module):
|
||||
# rearrange to (batch, tokens, size)
|
||||
clip_image_embeds = clip_image_embeds.permute(0, 2, 1)
|
||||
|
||||
# apply unconditional if doing cfg on embeds
|
||||
with torch.no_grad():
|
||||
if cfg_embed_strength is not None:
|
||||
uncond_clip = self.get_empty_clip_image(tensors_0_1.shape[0]).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
if self.config.quad_image:
|
||||
# split the 4x4 grid and stack on batch
|
||||
ci1, ci2 = uncond_clip.chunk(2, dim=2)
|
||||
ci1, ci3 = ci1.chunk(2, dim=3)
|
||||
ci2, ci4 = ci2.chunk(2, dim=3)
|
||||
to_cat = []
|
||||
for i, ci in enumerate([ci1, ci2, ci3, ci4]):
|
||||
if i < quad_count:
|
||||
to_cat.append(ci)
|
||||
else:
|
||||
break
|
||||
|
||||
uncond_clip = torch.cat(to_cat, dim=0).detach()
|
||||
uncond_clip_output = self.image_encoder(
|
||||
uncond_clip, output_hidden_states=True
|
||||
)
|
||||
|
||||
if self.config.clip_layer == 'penultimate_hidden_states':
|
||||
uncond_clip_output_embeds = uncond_clip_output.hidden_states[-2]
|
||||
elif self.config.clip_layer == 'last_hidden_state':
|
||||
uncond_clip_output_embeds = uncond_clip_output.hidden_states[-1]
|
||||
else:
|
||||
uncond_clip_output_embeds = uncond_clip_output.image_embeds
|
||||
if self.config.adapter_type == "clip_face":
|
||||
l2_norm = torch.norm(uncond_clip_output_embeds, p=2)
|
||||
uncond_clip_output_embeds = uncond_clip_output_embeds / l2_norm
|
||||
|
||||
uncond_clip_output_embeds = uncond_clip_output_embeds.detach()
|
||||
|
||||
|
||||
# apply inverse cfg
|
||||
clip_image_embeds = inverse_classifier_guidance(
|
||||
clip_image_embeds,
|
||||
uncond_clip_output_embeds,
|
||||
cfg_embed_strength
|
||||
)
|
||||
|
||||
|
||||
if self.config.quad_image:
|
||||
# get the outputs of the quat
|
||||
chunks = clip_image_embeds.chunk(quad_count, dim=0)
|
||||
|
||||
@@ -48,6 +48,7 @@ from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
|
||||
from toolkit.util.inverse_cfg import inverse_classifier_guidance
|
||||
|
||||
# tell it to shut up
|
||||
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
|
||||
@@ -234,13 +235,20 @@ class StableDiffusion:
|
||||
flush()
|
||||
print("Injecting alt weights")
|
||||
elif self.model_config.is_pixart:
|
||||
te_kwargs = {}
|
||||
# handle quantization of TE
|
||||
if self.model_config.text_encoder_bits == 8:
|
||||
te_kwargs['load_in_8bit'] = True
|
||||
te_kwargs['device_map'] = "auto"
|
||||
elif self.model_config.text_encoder_bits == 4:
|
||||
te_kwargs['load_in_4bit'] = True
|
||||
te_kwargs['device_map'] = "auto"
|
||||
# load the TE in 8bit mode
|
||||
text_encoder = T5EncoderModel.from_pretrained(
|
||||
"PixArt-alpha/PixArt-XL-2-1024-MS",
|
||||
subfolder="text_encoder",
|
||||
load_in_8bit=True,
|
||||
device_map="auto",
|
||||
torch_dtype=self.torch_dtype,
|
||||
**te_kwargs
|
||||
)
|
||||
|
||||
# load the transformer
|
||||
|
||||
25
toolkit/util/inverse_cfg.py
Normal file
25
toolkit/util/inverse_cfg.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import torch
|
||||
|
||||
|
||||
def inverse_classifier_guidance(
    noise_pred_cond: torch.Tensor,
    noise_pred_uncond: torch.Tensor,
    guidance_scale: float
):
    """
    Adjust noise_pred_cond so that applying classifier-free guidance to the
    adjusted value reproduces the original conditional prediction.

    CFG combines predictions as:
        noise_pred = noise_pred_uncond + guidance_scale * (cond - noise_pred_uncond)
    This returns ``cond_adj`` such that substituting it for ``cond`` above
    yields exactly the original ``noise_pred_cond``.

    :param noise_pred_cond: conditional prediction (or conditional embedding).
    :param noise_pred_uncond: unconditional prediction, broadcast-compatible
        with ``noise_pred_cond``.
    :param guidance_scale: CFG strength; a plain float or 0-dim tensor.
        Must be non-zero (division by ``guidance_scale`` below).
    :return: the adjusted conditional tensor.
    """
    # Derivation: solve  uncond + g * (cond_adj - uncond) == cond_orig  for cond_adj:
    #   cond_adj = uncond + (cond_orig - uncond) / g
    # NOTE(review): the previous implementation returned (cond - uncond) / g,
    # dropping the "+ uncond" term, so the CFG round-trip did not actually
    # reproduce noise_pred_cond as the docstring promised.
    if guidance_scale == 1:
        # With g == 1, CFG reduces to the identity on the conditional branch,
        # so no adjustment is needed (and this also sidesteps any edge cases).
        return noise_pred_cond
    adjusted_noise_pred_cond = noise_pred_uncond + (
        noise_pred_cond - noise_pred_uncond
    ) / guidance_scale
    return adjusted_noise_pred_cond
|
||||
Reference in New Issue
Block a user