Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-02-06 05:29:57 +00:00)
Bug fixes. Added the ability to use L1 loss. Various other tests and improvements.
@@ -6,10 +6,14 @@ import numpy as np
from diffusers import T2IAdapter, AutoencoderTiny

import torch.functional as F
from safetensors.torch import load_file
from torch.utils.data import DataLoader, ConcatDataset

from toolkit import train_tools
from toolkit.basic import value_map, adain, get_mean_std
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.config_modules import GuidanceConfig
from toolkit.data_loader import get_dataloader_datasets
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
from toolkit.image_utils import show_tensors, show_latents
@@ -46,6 +50,8 @@ class SDTrainer(BaseSDTrainProcess):
        self.do_guided_loss = False
        self.taesd: Optional[AutoencoderTiny] = None

+        self._clip_image_embeds_unconditional: Union[List[str], None] = None

    def before_model_load(self):
        pass
@@ -86,6 +92,22 @@ class SDTrainer(BaseSDTrainProcess):
        if self.adapter is not None:
            self.adapter.to(self.device_torch)

+        # check if we have regs and using adapter and caching clip embeddings
+        has_reg = self.datasets_reg is not None and len(self.datasets_reg) > 0
+        is_caching_clip_embeddings = any([self.datasets[i].cache_clip_vision_to_disk for i in range(len(self.datasets))])
+
+        if has_reg and is_caching_clip_embeddings:
+            # we need a list of unconditional clip image embeds from other datasets to handle regs
+            unconditional_clip_image_embeds = []
+            datasets = get_dataloader_datasets(self.data_loader)
+            for i in range(len(datasets)):
+                unconditional_clip_image_embeds += datasets[i].clip_vision_unconditional_cache
+
+            if len(unconditional_clip_image_embeds) == 0:
+                raise ValueError("No unconditional clip image embeds found. This should not happen")
+
+            self._clip_image_embeds_unconditional = unconditional_clip_image_embeds

    def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):
        # to process turbo learning, we make one big step from our current timestep to the end
        # we then denoise the prediction on that remaining step and target our loss to our target latents
@@ -190,6 +212,7 @@ class SDTrainer(BaseSDTrainProcess):
            **kwargs
    ):
        loss_target = self.train_config.loss_target
+        is_reg = any(batch.get_is_reg_list())

        prior_mask_multiplier = None
        target_mask_multiplier = None
@@ -202,6 +225,40 @@ class SDTrainer(BaseSDTrainProcess):
            noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
            noise_pred = noise_pred * (noise_norm / noise_pred_norm)

+        if self.train_config.correct_pred_norm and not is_reg:
+            with torch.no_grad():
+                # adjust the noise target in the opposite direction of the noise pred mean and std offset
+                # this will apply additional force to the model to correct itself to match the norm of the noise
+                noise_pred_mean, noise_pred_std = get_mean_std(noise_pred)
+                noise_mean, noise_std = get_mean_std(noise)
+
+                # apply the inverse offset of the mean and std to the noise
+                noise_additional_mean = noise_mean - noise_pred_mean
+                noise_additional_std = noise_std - noise_pred_std
+
+                # adjust for multiplier
+                noise_additional_mean = noise_additional_mean * self.train_config.correct_pred_norm_multiplier
+                noise_additional_std = noise_additional_std * self.train_config.correct_pred_norm_multiplier
+
+                noise_target_std = noise_std + noise_additional_std
+                noise_target_mean = noise_mean + noise_additional_mean
+
+                noise_pred_target_std = noise_pred_std - noise_additional_std
+                noise_pred_target_mean = noise_pred_mean - noise_additional_mean
+                noise_pred_target_std = noise_pred_target_std.detach()
+                noise_pred_target_mean = noise_pred_target_mean.detach()
+
+                # match the noise to the target
+                noise = (noise - noise_mean) / noise_std
+                noise = noise * noise_target_std + noise_target_mean
+                noise = noise.detach()
+
+                # match the noise pred to the target
+                # noise_pred = (noise_pred - noise_pred_mean) / noise_pred_std
+                # noise_pred = noise_pred * noise_pred_target_std + noise_pred_target_mean

        if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
            assert not self.train_config.train_turbo
            # we need to make the noise prediction be a masked blending of noise and prior_pred
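A minimal standalone sketch of the correct_pred_norm idea above (not part of the commit), assuming a get_mean_std that reduces over channel and spatial dims with keepdim: the noise target is re-standardized and shifted opposite to the prediction's mean/std drift.

import torch

def get_mean_std(t: torch.Tensor):
    # per-sample mean/std over channel and spatial dims (assumed shape: B, C, H, W)
    mean = t.mean(dim=(1, 2, 3), keepdim=True)
    std = t.std(dim=(1, 2, 3), keepdim=True)
    return mean, std

noise = torch.randn(2, 4, 8, 8)
noise_pred = noise * 1.1 + 0.05          # pretend the model drifts in scale and mean
multiplier = 1.0                          # stands in for correct_pred_norm_multiplier

noise_pred_mean, noise_pred_std = get_mean_std(noise_pred)
noise_mean, noise_std = get_mean_std(noise)

# offset of the prediction relative to the true noise, scaled by the multiplier
additional_mean = (noise_mean - noise_pred_mean) * multiplier
additional_std = (noise_std - noise_pred_std) * multiplier

# shift the target in the opposite direction of the prediction's drift
target_mean = noise_mean + additional_mean
target_std = noise_std + additional_std

corrected = (noise - noise_mean) / noise_std * target_std + target_mean
print(get_mean_std(corrected))            # mean/std pushed away from the drifted prediction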
@@ -227,7 +284,7 @@ class SDTrainer(BaseSDTrainProcess):
            target = prior_pred
        elif self.sd.prediction_type == 'v_prediction':
            # v-parameterization training
-            target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
+            target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
        else:
            target = noise
@@ -270,7 +327,10 @@ class SDTrainer(BaseSDTrainProcess):
            loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
            loss = loss_per_element
        else:
-            loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
+            if self.train_config.loss_type == "mae":
+                loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
+            else:
+                loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")

        # multiply by our mask
        loss = loss * mask_multiplier
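A small sketch of the new loss_type switch in isolation (the helper name is illustrative, not from the repo): "mae" selects elementwise L1, anything else falls back to elementwise MSE, and reduction="none" keeps per-element values for later masking.

import torch
import torch.nn.functional as F

def elementwise_loss(pred: torch.Tensor, target: torch.Tensor, loss_type: str = "mse") -> torch.Tensor:
    # mirrors the branch above
    if loss_type == "mae":
        return F.l1_loss(pred.float(), target.float(), reduction="none")
    return F.mse_loss(pred.float(), target.float(), reduction="none")

pred = torch.randn(2, 4, 8, 8)
target = torch.randn(2, 4, 8, 8)
print(elementwise_loss(pred, target, "mae").shape)   # torch.Size([2, 4, 8, 8])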
@@ -278,12 +338,11 @@ class SDTrainer(BaseSDTrainProcess):
        prior_loss = None
        if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
            assert not self.train_config.train_turbo
            # add a loss to unmasked areas of the prior for unmasked regularization
-            prior_loss = torch.nn.functional.mse_loss(
-                prior_pred.float(),
-                pred.float(),
-                reduction="none"
-            )
+            if self.train_config.loss_type == "mae":
+                prior_loss = torch.nn.functional.l1_loss(pred.float(), prior_pred.float(), reduction="none")
+            else:
+                prior_loss = torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none")

            prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
            if torch.isnan(prior_loss).any():
                print("Prior loss is nan")
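A hedged sketch of how the masked main loss and the inverted-mask prior loss could combine, assuming mask_multiplier is ~1 inside the mask and ~0 outside and prior_mask_multiplier is its inverse; the final reduction to a scalar is illustrative, not taken from the trainer.

import torch
import torch.nn.functional as F

pred = torch.randn(1, 4, 8, 8)
target = torch.randn(1, 4, 8, 8)
prior_pred = torch.randn(1, 4, 8, 8)

mask_multiplier = torch.zeros(1, 1, 8, 8)
mask_multiplier[..., 2:6, 2:6] = 1.0           # train only inside the mask
prior_mask_multiplier = 1.0 - mask_multiplier  # regularize the unmasked region toward the prior
inverted_mask_prior_multiplier = 0.5

loss = F.mse_loss(pred.float(), target.float(), reduction="none") * mask_multiplier
prior_loss = F.mse_loss(pred.float(), prior_pred.float(), reduction="none")
prior_loss = prior_loss * prior_mask_multiplier * inverted_mask_prior_multiplier

total = (loss + prior_loss).mean()
print(total)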
@@ -717,6 +776,13 @@ class SDTrainer(BaseSDTrainProcess):

        has_adapter_img = batch.control_tensor is not None
        has_clip_image = batch.clip_image_tensor is not None
+        has_clip_image_embeds = batch.clip_image_embeds is not None
+        # force it to be true if doing regs as we handle those differently
+        if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]):
+            has_clip_image = True
+            if self._clip_image_embeds_unconditional is not None:
+                has_clip_image_embeds = True  # we are caching embeds, handle that differently
+                has_clip_image = False

        if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
            raise ValueError(
@@ -996,7 +1062,39 @@ class SDTrainer(BaseSDTrainProcess):
                # number of images to do if doing a quad image
                quad_count = random.randint(1, 4)
                image_size = self.adapter.input_size
-                if is_reg:
+                if has_clip_image_embeds:
+                    # todo handle reg images better than this
+                    if is_reg:
+                        # get unconditional image embeds from cache
+                        embeds = [
+                            load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
+                            range(noisy_latents.shape[0])
+                        ]
+                        conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
+                            embeds,
+                            quad_count=quad_count
+                        )
+
+                        if self.train_config.do_cfg:
+                            embeds = [
+                                load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in range(noisy_latents.shape[0])
+                            ]
+                            unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
+                                embeds,
+                                quad_count=quad_count
+                            )
+
+                    else:
+                        conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
+                            batch.clip_image_embeds,
+                            quad_count=quad_count
+                        )
+                        if self.train_config.do_cfg:
+                            unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
+                                batch.clip_image_embeds_unconditional,
+                                quad_count=quad_count
+                            )
+                elif is_reg:
                    # we will zero it out in the img embedder
                    clip_images = torch.zeros(
                        (noisy_latents.shape[0], 3, image_size, image_size),
@@ -1071,9 +1169,9 @@ class SDTrainer(BaseSDTrainProcess):
            prior_pred = None

            do_reg_prior = False
-            # if is_reg and (self.network is not None or self.adapter is not None):
-            #     # we are doing a reg image and we have a network or adapter
-            #     do_reg_prior = True
+            if is_reg and (self.network is not None or self.adapter is not None):
+                # we are doing a reg image and we have a network or adapter
+                do_reg_prior = True

            do_inverted_masked_prior = False
            if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
@@ -1096,12 +1194,14 @@ class SDTrainer(BaseSDTrainProcess):

            # do the custom adapter after the prior prediction
            if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
+                quad_count = random.randint(1, 4)
                self.adapter.train()
                conditional_embeds = self.adapter.condition_encoded_embeds(
                    tensors_0_1=clip_images,
                    prompt_embeds=conditional_embeds,
                    is_training=True,
                    has_been_preprocessed=True,
+                    quad_count=quad_count
                )
                if self.train_config.do_cfg and unconditional_embeds is not None:
                    unconditional_embeds = self.adapter.condition_encoded_embeds(
@@ -1109,7 +1209,8 @@ class SDTrainer(BaseSDTrainProcess):
                        prompt_embeds=unconditional_embeds,
                        is_training=True,
                        has_been_preprocessed=True,
-                        is_unconditional=True
+                        is_unconditional=True,
+                        quad_count=quad_count
                    )
@@ -285,6 +285,13 @@ class TrainConfig:
        self.cfg_scale = kwargs.get('cfg_scale', 1.0)
        self.max_cfg_scale = kwargs.get('max_cfg_scale', self.cfg_scale)

+        # applies the inverse of the prediction mean and std to the target to correct
+        # for norm drift
+        self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
+        self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
+
+        self.loss_type = kwargs.get('loss_type', 'mse')


class ModelConfig:
    def __init__(self, **kwargs):
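How the new options might be consumed, as a toy stand-in for TrainConfig's kwargs.get pattern above (the class and values here are illustrative, not from the repo):

class ToyTrainConfig:
    # minimal stand-in mirroring the three options added in this hunk
    def __init__(self, **kwargs):
        self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
        self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
        self.loss_type = kwargs.get('loss_type', 'mse')

# e.g. kwargs parsed from the train section of a job config
cfg = ToyTrainConfig(loss_type='mae', correct_pred_norm=True)
print(cfg.loss_type, cfg.correct_pred_norm, cfg.correct_pred_norm_multiplier)  # mae True 1.0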
@@ -444,6 +451,7 @@ class DatasetConfig:
        self.cache_latents: bool = kwargs.get('cache_latents', False)
        # cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
        self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
+        self.cache_clip_vision_to_disk: bool = kwargs.get('cache_clip_vision_to_disk', False)

        self.standardize_images: bool = kwargs.get('standardize_images', False)
@@ -227,6 +227,16 @@ class CustomAdapter(torch.nn.Module):

            self.input_size = self.vision_encoder.config.image_size

+            if self.config.quad_image:  # 4x4 image
+                # self.clip_image_processor.config
+                # We do a 3x downscale of the image, so we need to adjust the input size
+                preprocessor_input_size = self.vision_encoder.config.image_size * 2
+
+                # update the preprocessor so images come in at the right size
+                self.image_processor.size['shortest_edge'] = preprocessor_input_size
+                self.image_processor.crop_size['height'] = preprocessor_input_size
+                self.image_processor.crop_size['width'] = preprocessor_input_size
+
            if self.config.image_encoder_arch == 'clip+':
                # self.image_processor.config
                # We do a 3x downscale of the image, so we need to adjust the input size
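A standalone illustration of the quad-image layout assumed above (not part of the commit): four images tiled into a 2x2 grid, so the tile fed to the preprocessor is twice the encoder's native image_size per side.

import torch

encoder_image_size = 224                      # e.g. vision_encoder.config.image_size
preprocessor_input_size = encoder_image_size * 2

imgs = [torch.rand(3, encoder_image_size, encoder_image_size) for _ in range(4)]
top = torch.cat([imgs[0], imgs[1]], dim=2)    # side by side
bottom = torch.cat([imgs[2], imgs[3]], dim=2)
grid = torch.cat([top, bottom], dim=1)        # stacked vertically

print(grid.shape)  # torch.Size([3, 448, 448]) -> matches preprocessor_input_size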
@@ -425,7 +435,8 @@ class CustomAdapter(torch.nn.Module):
            prompt_embeds: PromptEmbeds,
            is_training=False,
            has_been_preprocessed=False,
-            is_unconditional=False
+            is_unconditional=False,
+            quad_count=4,
    ) -> PromptEmbeds:
        if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
            if is_unconditional:
@@ -454,6 +465,20 @@ class CustomAdapter(torch.nn.Module):
                clip_image = tensors_0_1
                clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()

+                if self.config.quad_image:
+                    # split the 4x4 grid and stack on batch
+                    ci1, ci2 = clip_image.chunk(2, dim=2)
+                    ci1, ci3 = ci1.chunk(2, dim=3)
+                    ci2, ci4 = ci2.chunk(2, dim=3)
+                    to_cat = []
+                    for i, ci in enumerate([ci1, ci2, ci3, ci4]):
+                        if i < quad_count:
+                            to_cat.append(ci)
+                        else:
+                            break
+
+                    clip_image = torch.cat(to_cat, dim=0).detach()

                if self.adapter_type == 'photo_maker':
                    # Embeddings need to be (b, num_inputs, c, h, w) for now, just put 1 input image
                    clip_image = clip_image.unsqueeze(1)
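The inverse operation, sketched on its own under assumed shapes: chunking the grid along height then width recovers the four tiles and stacks them on the batch dimension, and quad_count < 4 keeps only the first tiles.

import torch

clip_image = torch.rand(1, 3, 448, 448)       # batch of one 2x2 grid
quad_count = 3

ci1, ci2 = clip_image.chunk(2, dim=2)         # split height: top / bottom halves
ci1, ci3 = ci1.chunk(2, dim=3)                # split top half by width
ci2, ci4 = ci2.chunk(2, dim=3)                # split bottom half by width

tiles = [ci1, ci2, ci3, ci4][:quad_count]
stacked = torch.cat(tiles, dim=0)
print(stacked.shape)                          # torch.Size([3, 3, 224, 224])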
@@ -496,6 +521,17 @@ class CustomAdapter(torch.nn.Module):

                img_embeds = id_embeds['last_hidden_state']

+                if self.config.quad_image:
+                    # get the outputs of the quad
+                    chunks = img_embeds.chunk(quad_count, dim=0)
+                    chunk_sum = torch.zeros_like(chunks[0])
+                    for chunk in chunks:
+                        chunk_sum = chunk_sum + chunk
+                    # get the mean of them
+
+                    img_embeds = chunk_sum / quad_count

                if not is_training or not self.config.train_image_encoder:
                    img_embeds = img_embeds.detach()
@@ -18,7 +18,7 @@ import albumentations as A

from toolkit.buckets import get_bucket_for_image_size, BucketResolution
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
-from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments
+from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO

if TYPE_CHECKING:
@@ -355,7 +355,7 @@ class PairedImageDataset(Dataset):
        return img, prompt, (self.neg_weight, self.pos_weight)


-class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
+class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, CaptionMixin, Dataset):

    def __init__(
            self,
@@ -373,6 +373,7 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
        self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
        self.is_caching_latents_to_memory = dataset_config.cache_latents
        self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
+        self.is_caching_clip_vision_to_disk = dataset_config.cache_clip_vision_to_disk
        self.epoch_num = 0

        self.sd = sd
@@ -482,6 +483,8 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
            self.setup_buckets()
            if self.is_caching_latents:
                self.cache_latents_all_latents()
+            if self.is_caching_clip_vision_to_disk:
+                self.cache_clip_vision_to_disk()
        else:
            if self.dataset_config.poi is not None:
                # handle cropping to a specific point of interest
@@ -611,3 +614,19 @@ def trigger_dataloader_setup_epoch(dataloader: DataLoader):
        if hasattr(sub_dataset, 'setup_epoch'):
            sub_dataset.setup_epoch()
            sub_dataset.len = None

+
+def get_dataloader_datasets(dataloader: DataLoader):
+    # hacky but needed because of different types of datasets and dataloaders
+    if isinstance(dataloader.dataset, list):
+        datasets = []
+        for dataset in dataloader.dataset:
+            if hasattr(dataset, 'datasets'):
+                for sub_dataset in dataset.datasets:
+                    datasets.append(sub_dataset)
+            else:
+                datasets.append(dataset)
+        return datasets
+    elif hasattr(dataloader.dataset, 'datasets'):
+        return dataloader.dataset.datasets
+    else:
+        return [dataloader.dataset]
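A hedged usage sketch for the helper added above, using plain torch datasets to stand in for AiToolkitDataset: it unwraps a ConcatDataset (or a list of them) back into the underlying datasets.

import torch
from torch.utils.data import DataLoader, ConcatDataset, TensorDataset

ds_a = TensorDataset(torch.arange(4))
ds_b = TensorDataset(torch.arange(6))
loader = DataLoader(ConcatDataset([ds_a, ds_b]), batch_size=2)

# ConcatDataset exposes its parts via .datasets, which is the attribute
# get_dataloader_datasets checks for
print(loader.dataset.datasets)   # [ds_a, ds_b]
# so get_dataloader_datasets(loader) would return these two datasets here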
@@ -96,6 +96,8 @@ class DataLoaderBatchDTO:
        self.unaugmented_tensor: Union[torch.Tensor, None] = None
        self.unconditional_tensor: Union[torch.Tensor, None] = None
        self.unconditional_latents: Union[torch.Tensor, None] = None
+        self.clip_image_embeds: Union[List[dict], None] = None
+        self.clip_image_embeds_unconditional: Union[List[dict], None] = None
        self.sigmas: Union[torch.Tensor, None] = None  # can be added elsewhere and passed along training code
        if not is_latents_cached:
            # only return a tensor if latents are not cached
@@ -183,6 +185,23 @@ class DataLoaderBatchDTO:
                    else:
                        unconditional_tensor.append(x.unconditional_tensor)
                self.unconditional_tensor = torch.cat([x.unsqueeze(0) for x in unconditional_tensor])

+            if any([x.clip_image_embeds is not None for x in self.file_items]):
+                self.clip_image_embeds = []
+                for x in self.file_items:
+                    if x.clip_image_embeds is not None:
+                        self.clip_image_embeds.append(x.clip_image_embeds)
+                    else:
+                        raise Exception("clip_image_embeds is None for some file items")
+
+            if any([x.clip_image_embeds_unconditional is not None for x in self.file_items]):
+                self.clip_image_embeds_unconditional = []
+                for x in self.file_items:
+                    if x.clip_image_embeds_unconditional is not None:
+                        self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional)
+                    else:
+                        raise Exception("clip_image_embeds_unconditional is None for some file items")

        except Exception as e:
            print(e)
            raise e
@@ -12,7 +12,7 @@ import numpy as np
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
-from transformers import CLIPImageProcessor
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from toolkit.basic import flush, value_map
from toolkit.buckets import get_bucket_for_image_size, get_resolution
@@ -570,9 +570,18 @@ class ClipImageFileItemDTOMixin:
        self.has_clip_image = False
        self.clip_image_path: Union[str, None] = None
        self.clip_image_tensor: Union[torch.Tensor, None] = None
+        self.clip_image_embeds: Union[dict, None] = None
+        self.clip_image_embeds_unconditional: Union[dict, None] = None
        self.has_clip_augmentations = False
        self.clip_image_aug_transform: Union[None, A.Compose] = None
        self.clip_image_processor: Union[None, CLIPImageProcessor] = None
+        self.clip_image_encoder_path: Union[str, None] = None
+        self.is_caching_clip_vision_to_disk = False
+        self.is_vision_clip_cached = False
+        self.clip_vision_is_quad = False
+        self.clip_vision_load_device = 'cpu'
+        self.clip_vision_unconditional_paths: Union[List[str], None] = None
+        self._clip_vision_embeddings_path: Union[str, None] = None
        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        if dataset_config.clip_image_path is not None:
            # copy the clip image processor so the dataloader can do it
@@ -633,7 +642,45 @@ class ClipImageFileItemDTOMixin:

        return augmented_tensor

+    def get_clip_vision_info_dict(self: 'FileItemDTO'):
+        item = OrderedDict([
+            ("image_encoder_path", self.clip_image_encoder_path),
+            ("filename", os.path.basename(self.clip_image_path)),
+            ("is_quad", self.clip_vision_is_quad)
+        ])
+        # when adding items, do it after so we don't change old latents
+        if self.flip_x:
+            item["flip_x"] = True
+        if self.flip_y:
+            item["flip_y"] = True
+        return item
+
+    def get_clip_vision_embeddings_path(self: 'FileItemDTO', recalculate=False):
+        if self._clip_vision_embeddings_path is not None and not recalculate:
+            return self._clip_vision_embeddings_path
+        else:
+            # we store the embeddings in a _clip_vision_cache folder in the same path as the image
+            img_dir = os.path.dirname(self.clip_image_path)
+            latent_dir = os.path.join(img_dir, '_clip_vision_cache')
+            hash_dict = self.get_clip_vision_info_dict()
+            filename_no_ext = os.path.splitext(os.path.basename(self.clip_image_path))[0]
+            # get base64 hash of md5 checksum of hash_dict
+            hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
+            hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
+            hash_str = hash_str.replace('=', '')
+            self._clip_vision_embeddings_path = os.path.join(latent_dir, f'{filename_no_ext}_{hash_str}.safetensors')

+        return self._clip_vision_embeddings_path
+
    def load_clip_image(self: 'FileItemDTO'):
+        if self.is_vision_clip_cached:
+            self.clip_image_embeds = load_file(self.get_clip_vision_embeddings_path())
+
+            # get a random unconditional image
+            if self.clip_vision_unconditional_paths is not None:
+                unconditional_path = random.choice(self.clip_vision_unconditional_paths)
+                self.clip_image_embeds_unconditional = load_file(unconditional_path)
+
+            return
        img = Image.open(self.clip_image_path).convert('RGB')
        try:
            img = exif_transpose(img)
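A self-contained sketch of the cache-key scheme used above (helper name is illustrative): the info dict is JSON-serialized with sorted keys, MD5-hashed, and base64url-encoded with padding stripped, so a change to encoder path, flips, or quad mode yields a new cache file name.

import base64
import hashlib
import json
import os
from collections import OrderedDict

def cache_path_for(clip_image_path: str, image_encoder_path: str, is_quad: bool) -> str:
    hash_dict = OrderedDict([
        ("image_encoder_path", image_encoder_path),
        ("filename", os.path.basename(clip_image_path)),
        ("is_quad", is_quad),
    ])
    hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
    hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii').replace('=', '')
    filename_no_ext = os.path.splitext(os.path.basename(clip_image_path))[0]
    return os.path.join(os.path.dirname(clip_image_path), '_clip_vision_cache', f'{filename_no_ext}_{hash_str}.safetensors')

print(cache_path_for('/data/imgs/cat.png', 'openai/clip-vit-large-patch14', False))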
@@ -683,6 +730,7 @@ class ClipImageFileItemDTOMixin:

    def cleanup_clip_image(self: 'FileItemDTO'):
        self.clip_image_tensor = None
+        self.clip_image_embeds = None
@@ -1273,7 +1321,7 @@ class LatentCachingMixin:
                del latent
                del file_item.tensor

-                flush(garbage_collect=False)
+                # flush(garbage_collect=False)
                file_item.is_latent_cached = True
                i += 1
                # flush every 100
@@ -1282,3 +1330,176 @@ class LatentCachingMixin:

            # restore device state
            self.sd.restore_device_state()
+
+
+class CLIPCachingMixin:
+    def __init__(self: 'AiToolkitDataset', **kwargs):
+        # if we have super, call it
+        if hasattr(super(), '__init__'):
+            super().__init__(**kwargs)
+        self.clip_vision_num_unconditional_cache = 20
+        self.clip_vision_unconditional_cache = []
+
+    def cache_clip_vision_to_disk(self: 'AiToolkitDataset'):
+        if not self.is_caching_clip_vision_to_disk:
+            return
+        with torch.no_grad():
+            print(f"Caching clip vision for {self.dataset_path}")
+
+            print(" - Saving clip to disk")
+            # move sd items to cpu except for vae
+            self.sd.set_device_state_preset('cache_clip')
+
+            # make sure the adapter has attributes
+            if self.sd.adapter is None:
+                raise Exception("Error: must have an adapter to cache clip vision to disk")
+
+            clip_image_processor: CLIPImageProcessor = None
+            if hasattr(self.sd.adapter, 'clip_image_processor'):
+                clip_image_processor = self.sd.adapter.clip_image_processor
+
+            if clip_image_processor is None:
+                raise Exception("Error: must have a clip image processor to cache clip vision to disk")
+
+            vision_encoder: CLIPVisionModelWithProjection = None
+            if hasattr(self.sd.adapter, 'image_encoder'):
+                vision_encoder = self.sd.adapter.image_encoder
+            if hasattr(self.sd.adapter, 'vision_encoder'):
+                vision_encoder = self.sd.adapter.vision_encoder
+
+            if vision_encoder is None:
+                raise Exception("Error: must have a vision encoder to cache clip vision to disk")
+
+            # move vision encoder to device
+            vision_encoder.to(self.sd.device)
+
+            is_quad = self.sd.adapter.config.quad_image
+            image_encoder_path = self.sd.adapter.config.image_encoder_path
+
+            dtype = self.sd.torch_dtype
+            device = self.sd.device_torch
+            if hasattr(self.sd.adapter, 'clip_noise_zero') and self.sd.adapter.clip_noise_zero:
+                # just to do this, we did :)
+                # need more samples as it is random noise
+                self.clip_vision_num_unconditional_cache = self.clip_vision_num_unconditional_cache
+            else:
+                # only need one since it doesn't change
+                self.clip_vision_num_unconditional_cache = 1
+
+            # cache unconditionals
+            print(f" - Caching {self.clip_vision_num_unconditional_cache} unconditional clip vision to disk")
+            clip_vision_cache_path = os.path.join(self.dataset_config.clip_image_path, '_clip_vision_cache')
+
+            unconditional_paths = []
+
+            is_noise_zero = hasattr(self.sd.adapter, 'clip_noise_zero') and self.sd.adapter.clip_noise_zero
+
+            for i in range(self.clip_vision_num_unconditional_cache):
+                hash_dict = OrderedDict([
+                    ("image_encoder_path", image_encoder_path),
+                    ("is_quad", is_quad),
+                    ("is_noise_zero", is_noise_zero),
+                ])
+                # get base64 hash of md5 checksum of hash_dict
+                hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
+                hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
+                hash_str = hash_str.replace('=', '')
+
+                uncond_path = os.path.join(clip_vision_cache_path, f'uncond_{hash_str}_{i}.safetensors')
+                if os.path.exists(uncond_path):
+                    # skip it
+                    unconditional_paths.append(uncond_path)
+                    continue
+
+                # generate a random image
+                img_shape = (1, 3, self.sd.adapter.input_size, self.sd.adapter.input_size)
+                if is_noise_zero:
+                    tensors_0_1 = torch.rand(img_shape).to(device, dtype=torch.float32)
+                else:
+                    tensors_0_1 = torch.zeros(img_shape).to(device, dtype=torch.float32)
+                clip_image = clip_image_processor(
+                    images=tensors_0_1,
+                    return_tensors="pt",
+                    do_resize=True,
+                    do_rescale=False,
+                ).pixel_values
+
+                if is_quad:
+                    # split the 4x4 grid and stack on batch
+                    ci1, ci2 = clip_image.chunk(2, dim=2)
+                    ci1, ci3 = ci1.chunk(2, dim=3)
+                    ci2, ci4 = ci2.chunk(2, dim=3)
+                    clip_image = torch.cat([ci1, ci2, ci3, ci4], dim=0).detach()
+
+                clip_output = vision_encoder(
+                    clip_image.to(device, dtype=dtype),
+                    output_hidden_states=True
+                )
+                # make state_dict ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states']
+                state_dict = OrderedDict([
+                    ('image_embeds', clip_output.image_embeds.clone().detach().cpu()),
+                    ('last_hidden_state', clip_output.hidden_states[-1].clone().detach().cpu()),
+                    ('penultimate_hidden_states', clip_output.hidden_states[-2].clone().detach().cpu()),
+                ])
+
+                os.makedirs(os.path.dirname(uncond_path), exist_ok=True)
+                save_file(state_dict, uncond_path)
+                unconditional_paths.append(uncond_path)
+
+            self.clip_vision_unconditional_cache = unconditional_paths
+
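Once cached, a file like the one written above can be read back with safetensors load_file and exposes the same three keys; the dummy tensors and shapes below are illustrative only.

import torch
from safetensors.torch import load_file, save_file

# write a dummy cache entry with the same keys as the state_dict above (shapes are placeholders)
state_dict = {
    'image_embeds': torch.zeros(1, 768),
    'last_hidden_state': torch.zeros(1, 257, 1024),
    'penultimate_hidden_states': torch.zeros(1, 257, 1024),
}
save_file(state_dict, 'uncond_example.safetensors')

embeds = load_file('uncond_example.safetensors')
print(sorted(embeds.keys()))  # ['image_embeds', 'last_hidden_state', 'penultimate_hidden_states']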
+            # use tqdm to show progress
+            i = 0
+            for file_item in tqdm(self.file_list, desc=f'Caching clip vision to disk'):
+                file_item.is_caching_clip_vision_to_disk = True
+                file_item.clip_vision_load_device = self.sd.device
+                file_item.clip_vision_is_quad = is_quad
+                file_item.clip_image_encoder_path = image_encoder_path
+                file_item.clip_vision_unconditional_paths = unconditional_paths
+                if file_item.has_clip_augmentations:
+                    raise Exception("Error: clip vision caching is not supported with clip augmentations")
+
+                embedding_path = file_item.get_clip_vision_embeddings_path(recalculate=True)
+                # check if it is saved to disk already
+                if not os.path.exists(embedding_path):
+                    # load the image first
+                    file_item.load_clip_image()
+                    # add batch dimension
+                    clip_image = file_item.clip_image_tensor.unsqueeze(0).to(device, dtype=dtype)
+
+                    if is_quad:
+                        # split the 4x4 grid and stack on batch
+                        ci1, ci2 = clip_image.chunk(2, dim=2)
+                        ci1, ci3 = ci1.chunk(2, dim=3)
+                        ci2, ci4 = ci2.chunk(2, dim=3)
+                        clip_image = torch.cat([ci1, ci2, ci3, ci4], dim=0).detach()
+
+                    clip_output = vision_encoder(
+                        clip_image.to(device, dtype=dtype),
+                        output_hidden_states=True
+                    )
+
+                    # make state_dict ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states']
+                    state_dict = OrderedDict([
+                        ('image_embeds', clip_output.image_embeds.clone().detach().cpu()),
+                        ('last_hidden_state', clip_output.hidden_states[-1].clone().detach().cpu()),
+                        ('penultimate_hidden_states', clip_output.hidden_states[-2].clone().detach().cpu()),
+                    ])
+                    # metadata
+                    meta = get_meta_for_safetensors(file_item.get_clip_vision_info_dict())
+                    os.makedirs(os.path.dirname(embedding_path), exist_ok=True)
+                    save_file(state_dict, embedding_path, metadata=meta)
+
+                    del clip_image
+                    del clip_output
+                    del file_item.clip_image_tensor
+
+                # flush(garbage_collect=False)
+                file_item.is_vision_clip_cached = True
+                i += 1
+                # flush every 100
+                # if i % 100 == 0:
+                #     flush()
+
+            # restore device state
+            self.sd.restore_device_state()
@@ -249,9 +249,13 @@ class IPAdapter(torch.nn.Module):
            preprocessor_input_size = self.image_encoder.config.image_size * 2

            # update the preprocessor so images come in at the right size
-            self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
-            self.clip_image_processor.crop_size['height'] = preprocessor_input_size
-            self.clip_image_processor.crop_size['width'] = preprocessor_input_size
+            if 'height' in self.clip_image_processor.size:
+                self.clip_image_processor.size['height'] = preprocessor_input_size
+                self.clip_image_processor.size['width'] = preprocessor_input_size
+            elif hasattr(self.clip_image_processor, 'crop_size'):
+                self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
+                self.clip_image_processor.crop_size['height'] = preprocessor_input_size
+                self.clip_image_processor.crop_size['width'] = preprocessor_input_size

        if self.config.image_encoder_arch == 'clip+':
            # self.clip_image_processor.config
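The reason for the branch above, shown without transformers as a dependency: depending on the checkpoint, an image processor's size dict may carry shortest_edge or explicit height/width, so both layouts have to be updated (plain dicts stand in for the processor attributes here).

def resize_processor(size: dict, crop_size: dict, new_size: int) -> None:
    # mirrors the added branch: prefer height/width when present,
    # otherwise fall back to shortest_edge plus crop_size
    if 'height' in size:
        size['height'] = new_size
        size['width'] = new_size
    else:
        size['shortest_edge'] = new_size
        crop_size['height'] = new_size
        crop_size['width'] = new_size

a_size, a_crop = {'shortest_edge': 224}, {'height': 224, 'width': 224}
b_size, b_crop = {'height': 224, 'width': 224}, {}
resize_processor(a_size, a_crop, 448)
resize_processor(b_size, b_crop, 448)
print(a_size, a_crop, b_size)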
@@ -439,6 +443,32 @@ class IPAdapter(torch.nn.Module):
        if self.preprocessor is not None:
            self.preprocessor.to(*args, **kwargs)
        return self

+    def parse_clip_image_embeds_from_cache(
+            self,
+            image_embeds_list: List[dict],  # has ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states']
+            quad_count=4,
+    ):
+        with torch.no_grad():
+            device = self.sd_ref().unet.device
+            if self.config.type.startswith('ip+'):
+                clip_image_embeds = torch.cat([x['penultimate_hidden_states'] for x in image_embeds_list], dim=0)
+            else:
+                clip_image_embeds = torch.cat([x['image_embeds'] for x in image_embeds_list], dim=0)
+
+            if self.config.quad_image:
+                # get the outputs of the quad
+                chunks = clip_image_embeds.chunk(quad_count, dim=0)
+                chunk_sum = torch.zeros_like(chunks[0])
+                for chunk in chunks:
+                    chunk_sum = chunk_sum + chunk
+                # get the mean of them
+
+                clip_image_embeds = chunk_sum / quad_count
+
+            clip_image_embeds = clip_image_embeds.to(device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
+            return clip_image_embeds
+
    def get_clip_image_embeds_from_tensors(
            self,
            tensors_0_1: torch.Tensor,
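The chunk-and-average step in the new method is equivalent to stacking the tiles and taking a mean; a quick check under assumed shapes:

import torch

quad_count = 4
# pretend these are penultimate hidden states for 4 grid tiles stacked on the batch dim
clip_image_embeds = torch.randn(quad_count, 257, 1024)

chunks = clip_image_embeds.chunk(quad_count, dim=0)
chunk_sum = torch.zeros_like(chunks[0])
for chunk in chunks:
    chunk_sum = chunk_sum + chunk
mean_via_loop = chunk_sum / quad_count

mean_via_stack = clip_image_embeds.mean(dim=0, keepdim=True)
print(torch.allclose(mean_via_loop, mean_via_stack))  # True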
@@ -172,6 +172,8 @@ class CLIPFusionModule(nn.Module):
            dim=self.text_hidden_size,
        )

+        self.alpha = nn.Parameter(torch.zeros([text_tokens]) + 0.01)
+
    def forward(self, text_embeds, vision_embeds):
        # text_embeds = (batch_size, 77, 768)
        # vision_embeds = (batch_size, 257, 1024)
@@ -186,7 +188,12 @@ class CLIPFusionModule(nn.Module):
        x = x + res

        # alpha mask
-        alpha = self.ctx_alpha(text_embeds)
-        x = alpha * x + (1 - alpha) * text_embeds
+        ctx_alpha = self.ctx_alpha(text_embeds)
+        # reshape alpha to (1, 77, 1)
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
+
+        x = ctx_alpha * x * alpha
+
+        x = x + text_embeds

        return x
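Shape check for the new blending, assuming 77 text tokens and a hidden size of 768 as in the comments above: the learned per-token alpha broadcasts as (1, 77, 1), and the gated fusion is added residually to the text embeds instead of the old convex blend.

import torch
import torch.nn as nn

batch, text_tokens, dim = 2, 77, 768
text_embeds = torch.randn(batch, text_tokens, dim)
x = torch.randn(batch, text_tokens, dim)                       # fused text/vision features
ctx_alpha = torch.sigmoid(torch.randn(batch, text_tokens, 1))  # stand-in for self.ctx_alpha(text_embeds)

alpha_param = nn.Parameter(torch.zeros([text_tokens]) + 0.01)
alpha = alpha_param.unsqueeze(0).unsqueeze(-1)                 # (1, 77, 1), broadcasts over batch and dim

out = ctx_alpha * x * alpha + text_embeds
print(out.shape)                                               # torch.Size([2, 77, 768])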
@@ -841,6 +841,14 @@ class StableDiffusion:
        else:
            timestep = timestep.repeat(latents.shape[0], 0)

+        # handle t2i adapters
+        if 'down_intrablock_additional_residuals' in kwargs:
+            # go through each item and concat if doing cfg and it doesn't have the same shape
+            for idx, item in enumerate(kwargs['down_intrablock_additional_residuals']):
+                if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
+                    kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([item] * 2, dim=0)
+
        def scale_model_input(model_input, timestep_tensor):
            if is_input_scaled:
                return model_input
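What the added t2i-adapter handling does, in isolation: when classifier-free guidance doubles the batch, any residual whose batch dimension was not doubled is concatenated with itself so it lines up with the cond/uncond halves. Shapes below are placeholders.

import torch

do_classifier_free_guidance = True
text_embeds_batch = 4                                 # cond + uncond halves
residuals = [torch.randn(2, 320, 64, 64), torch.randn(4, 640, 32, 32)]

for idx, item in enumerate(residuals):
    if do_classifier_free_guidance and item.shape[0] != text_embeds_batch:
        residuals[idx] = torch.cat([item] * 2, dim=0)

print([r.shape[0] for r in residuals])                # [4, 4]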
@@ -1599,6 +1607,8 @@ class StableDiffusion:
        training_modules = []
        if device_state_preset in ['cache_latents']:
            active_modules = ['vae']
+        if device_state_preset in ['cache_clip']:
+            active_modules = ['clip']
        if device_state_preset in ['generate']:
            active_modules = ['vae', 'unet', 'text_encoder', 'adapter', 'refiner_unet']