Bug fixes. Added ability to use L1 loss. Various other tests and improvements

This commit is contained in:
Jaret Burkett
2024-01-31 06:30:54 -07:00
parent 92b9c71d44
commit 1ae1017748
9 changed files with 474 additions and 23 deletions

View File

@@ -6,10 +6,14 @@ import numpy as np
from diffusers import T2IAdapter, AutoencoderTiny
import torch.nn.functional as F
from safetensors.torch import load_file
from torch.utils.data import DataLoader, ConcatDataset
from toolkit import train_tools
from toolkit.basic import value_map, adain, get_mean_std
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.config_modules import GuidanceConfig
from toolkit.data_loader import get_dataloader_datasets
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
from toolkit.image_utils import show_tensors, show_latents
@@ -46,6 +50,8 @@ class SDTrainer(BaseSDTrainProcess):
self.do_guided_loss = False
self.taesd: Optional[AutoencoderTiny] = None
self._clip_image_embeds_unconditional: Union[List[str], None] = None
def before_model_load(self):
pass
@@ -86,6 +92,22 @@ class SDTrainer(BaseSDTrainProcess):
if self.adapter is not None:
self.adapter.to(self.device_torch)
# check if we have regs and using adapter and caching clip embeddings
has_reg = self.datasets_reg is not None and len(self.datasets_reg) > 0
is_caching_clip_embeddings = any([self.datasets[i].cache_clip_vision_to_disk for i in range(len(self.datasets))])
if has_reg and is_caching_clip_embeddings:
# we need a list of unconditional clip image embeds from other datasets to handle regs
unconditional_clip_image_embeds = []
datasets = get_dataloader_datasets(self.data_loader)
for i in range(len(datasets)):
unconditional_clip_image_embeds += datasets[i].clip_vision_unconditional_cache
if len(unconditional_clip_image_embeds) == 0:
raise ValueError("No unconditional clip image embeds found. This should not happen")
self._clip_image_embeds_unconditional = unconditional_clip_image_embeds
def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):
# for turbo training, we take one big step from the current timestep to the end
# we then denoise the prediction over that remaining step and compute the loss against the target latents
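The body of process_output_for_turbo is not shown in this hunk. A minimal sketch of the single-step denoising the comment describes, assuming an epsilon-prediction model and a DDPM-style scheduler that exposes alphas_cumprod; the names and wiring here are illustrative, not the repo's actual implementation:
import torch

def one_step_denoise(noisy_latents, noise_pred, timesteps, alphas_cumprod):
    # gather alpha_bar_t for each sample and broadcast over (C, H, W)
    alpha_bar_t = alphas_cumprod[timesteps].view(-1, 1, 1, 1)
    # x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
    return (noisy_latents - torch.sqrt(1.0 - alpha_bar_t) * noise_pred) / torch.sqrt(alpha_bar_t)

# the loss would then be taken against the clean target latents, e.g.
# loss = torch.nn.functional.mse_loss(denoised, target_latents)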
@@ -190,6 +212,7 @@ class SDTrainer(BaseSDTrainProcess):
**kwargs
):
loss_target = self.train_config.loss_target
is_reg = any(batch.get_is_reg_list())
prior_mask_multiplier = None
target_mask_multiplier = None
@@ -202,6 +225,40 @@ class SDTrainer(BaseSDTrainProcess):
noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
noise_pred = noise_pred * (noise_norm / noise_pred_norm)
if self.train_config.correct_pred_norm and not is_reg:
with torch.no_grad():
# adjust the noise target in the opposite direction of the noise pred mean and std offset
# this applies additional force to push the model to correct itself toward the norm of the noise
noise_pred_mean, noise_pred_std = get_mean_std(noise_pred)
noise_mean, noise_std = get_mean_std(noise)
# apply the inverse offset of the mean and std to the noise
noise_additional_mean = noise_mean - noise_pred_mean
noise_additional_std = noise_std - noise_pred_std
# adjust for multiplier
noise_additional_mean = noise_additional_mean * self.train_config.correct_pred_norm_multiplier
noise_additional_std = noise_additional_std * self.train_config.correct_pred_norm_multiplier
noise_target_std = noise_std + noise_additional_std
noise_target_mean = noise_mean + noise_additional_mean
noise_pred_target_std = noise_pred_std - noise_additional_std
noise_pred_target_mean = noise_pred_mean - noise_additional_mean
noise_pred_target_std = noise_pred_target_std.detach()
noise_pred_target_mean = noise_pred_target_mean.detach()
# match the noise to the target
noise = (noise - noise_mean) / noise_std
noise = noise * noise_target_std + noise_target_mean
noise = noise.detach()
# match the noise pred to the target
# noise_pred = (noise_pred - noise_pred_mean) / noise_pred_std
# noise_pred = noise_pred * noise_pred_target_std + noise_pred_target_mean
if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
assert not self.train_config.train_turbo
# we need to make the noise prediction be a masked blending of noise and prior_pred
@@ -227,7 +284,7 @@ class SDTrainer(BaseSDTrainProcess):
target = prior_pred
elif self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
else:
target = noise
@@ -270,7 +327,10 @@ class SDTrainer(BaseSDTrainProcess):
loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
loss = loss_per_element
else:
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
if self.train_config.loss_type == "mae":
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
else:
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
# multiply by our mask
loss = loss * mask_multiplier
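For reference, a standalone sketch of the two per-element losses selected by the new loss_type option (shapes here are arbitrary; both use reduction="none" so the mask multiplier above can still be applied):
import torch
import torch.nn.functional as F

pred = torch.randn(2, 4, 8, 8)
target = torch.randn(2, 4, 8, 8)

# loss_type == "mae": L1 loss, less sensitive to large residuals
mae = F.l1_loss(pred.float(), target.float(), reduction="none")
# default loss_type == "mse": squared error, penalizes large residuals more heavily
mse = F.mse_loss(pred.float(), target.float(), reduction="none")

assert mae.shape == mse.shape == pred.shape  # per-element, so masking still works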
@@ -278,12 +338,11 @@ class SDTrainer(BaseSDTrainProcess):
prior_loss = None
if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
assert not self.train_config.train_turbo
# apply a loss to the unmasked areas of the prior for unmasked regularization
prior_loss = torch.nn.functional.mse_loss(
prior_pred.float(),
pred.float(),
reduction="none"
)
if self.train_config.loss_type == "mae":
prior_loss = torch.nn.functional.l1_loss(pred.float(), prior_pred.float(), reduction="none")
else:
prior_loss = torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none")
prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
if torch.isnan(prior_loss).any():
print("Prior loss is nan")
@@ -717,6 +776,13 @@ class SDTrainer(BaseSDTrainProcess):
has_adapter_img = batch.control_tensor is not None
has_clip_image = batch.clip_image_tensor is not None
has_clip_image_embeds = batch.clip_image_embeds is not None
# force it to be true if doing regs as we handle those differently
if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]):
has_clip_image = True
if self._clip_image_embeds_unconditional is not None:
has_clip_image_embeds = True # we are caching embeds, handle that differently
has_clip_image = False
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
raise ValueError(
@@ -996,7 +1062,39 @@ class SDTrainer(BaseSDTrainProcess):
# number of images to do if doing a quad image
quad_count = random.randint(1, 4)
image_size = self.adapter.input_size
if is_reg:
if has_clip_image_embeds:
# todo handle reg images better than this
if is_reg:
# get unconditional image embeds from the cache
embeds = [
load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
range(noisy_latents.shape[0])
]
conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
embeds,
quad_count=quad_count
)
if self.train_config.do_cfg:
embeds = [
load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in range(noisy_latents.shape[0])
]
unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
embeds,
quad_count=quad_count
)
else:
conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
batch.clip_image_embeds,
quad_count=quad_count
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
batch.clip_image_embeds_unconditional,
quad_count=quad_count
)
elif is_reg:
# we will zero it out in the img embedder
clip_images = torch.zeros(
(noisy_latents.shape[0], 3, image_size, image_size),
@@ -1071,9 +1169,9 @@ class SDTrainer(BaseSDTrainProcess):
prior_pred = None
do_reg_prior = False
# if is_reg and (self.network is not None or self.adapter is not None):
# # we are doing a reg image and we have a network or adapter
# do_reg_prior = True
if is_reg and (self.network is not None or self.adapter is not None):
# we are doing a reg image and we have a network or adapter
do_reg_prior = True
do_inverted_masked_prior = False
if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
@@ -1096,12 +1194,14 @@ class SDTrainer(BaseSDTrainProcess):
# do the custom adapter after the prior prediction
if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
quad_count = random.randint(1, 4)
self.adapter.train()
conditional_embeds = self.adapter.condition_encoded_embeds(
tensors_0_1=clip_images,
prompt_embeds=conditional_embeds,
is_training=True,
has_been_preprocessed=True,
quad_count=quad_count
)
if self.train_config.do_cfg and unconditional_embeds is not None:
unconditional_embeds = self.adapter.condition_encoded_embeds(
@@ -1109,7 +1209,8 @@ class SDTrainer(BaseSDTrainProcess):
prompt_embeds=unconditional_embeds,
is_training=True,
has_been_preprocessed=True,
is_unconditional=True
is_unconditional=True,
quad_count=quad_count
)

View File

@@ -285,6 +285,13 @@ class TrainConfig:
self.cfg_scale = kwargs.get('cfg_scale', 1.0)
self.max_cfg_scale = kwargs.get('max_cfg_scale', self.cfg_scale)
# applies the inverse of the prediction mean and std to the target to correct
# for norm drift
self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
self.loss_type = kwargs.get('loss_type', 'mse')
class ModelConfig:
def __init__(self, **kwargs):
@@ -444,6 +451,7 @@ class DatasetConfig:
self.cache_latents: bool = kwargs.get('cache_latents', False)
# cache_latents_to_disk will store them on disk. If both are true, it will save to disk but also keep them in memory
self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
self.cache_clip_vision_to_disk: bool = kwargs.get('cache_clip_vision_to_disk', False)
self.standardize_images: bool = kwargs.get('standardize_images', False)
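A hedged example of how the new options above might be passed in; the key names come from this diff, but the values and surrounding structure are assumptions:
# illustrative kwargs only; actual configs are loaded by the toolkit's own config loader
train_kwargs = {
    "loss_type": "mae",                 # new: "mse" (default) or "mae" for L1 loss
    "correct_pred_norm": True,          # new: correct for prediction norm drift
    "correct_pred_norm_multiplier": 1.0,
}
dataset_kwargs = {
    "cache_latents_to_disk": True,
    "cache_clip_vision_to_disk": True,  # new: cache clip vision embeddings to disk
}
# e.g. TrainConfig(**train_kwargs) and DatasetConfig(**dataset_kwargs)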

View File

@@ -227,6 +227,16 @@ class CustomAdapter(torch.nn.Module):
self.input_size = self.vision_encoder.config.image_size
if self.config.quad_image: # 2x2 grid of 4 images
# self.clip_image_processor.config
# the 2x2 grid doubles each image dimension, so double the preprocessor input size
preprocessor_input_size = self.vision_encoder.config.image_size * 2
# update the preprocessor so images come in at the right size
self.image_processor.size['shortest_edge'] = preprocessor_input_size
self.image_processor.crop_size['height'] = preprocessor_input_size
self.image_processor.crop_size['width'] = preprocessor_input_size
if self.config.image_encoder_arch == 'clip+':
# self.image_processor.config
# We do a 3x downscale of the image, so we need to adjust the input size
@@ -425,7 +435,8 @@ class CustomAdapter(torch.nn.Module):
prompt_embeds: PromptEmbeds,
is_training=False,
has_been_preprocessed=False,
is_unconditional=False
is_unconditional=False,
quad_count=4,
) -> PromptEmbeds:
if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
if is_unconditional:
@@ -454,6 +465,20 @@ class CustomAdapter(torch.nn.Module):
clip_image = tensors_0_1
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
if self.config.quad_image:
# split the 2x2 grid into 4 images and stack on the batch dimension
ci1, ci2 = clip_image.chunk(2, dim=2)
ci1, ci3 = ci1.chunk(2, dim=3)
ci2, ci4 = ci2.chunk(2, dim=3)
to_cat = []
for i, ci in enumerate([ci1, ci2, ci3, ci4]):
if i < quad_count:
to_cat.append(ci)
else:
break
clip_image = torch.cat(to_cat, dim=0).detach()
if self.adapter_type == 'photo_maker':
# Embeddings need to be (b, num_inputs, c, h, w) for now, just put 1 input image
clip_image = clip_image.unsqueeze(1)
@@ -496,6 +521,17 @@ class CustomAdapter(torch.nn.Module):
img_embeds = id_embeds['last_hidden_state']
if self.config.quad_image:
# get the outputs of the quad images
chunks = img_embeds.chunk(quad_count, dim=0)
chunk_sum = torch.zeros_like(chunks[0])
for chunk in chunks:
chunk_sum = chunk_sum + chunk
# get the mean of them
img_embeds = chunk_sum / quad_count
if not is_training or not self.config.train_image_encoder:
img_embeds = img_embeds.detach()

View File

@@ -18,7 +18,7 @@ import albumentations as A
from toolkit.buckets import get_bucket_for_image_size, BucketResolution
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
if TYPE_CHECKING:
@@ -355,7 +355,7 @@ class PairedImageDataset(Dataset):
return img, prompt, (self.neg_weight, self.pos_weight)
class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, CaptionMixin, Dataset):
def __init__(
self,
@@ -373,6 +373,7 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
self.is_caching_latents_to_memory = dataset_config.cache_latents
self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
self.is_caching_clip_vision_to_disk = dataset_config.cache_clip_vision_to_disk
self.epoch_num = 0
self.sd = sd
@@ -482,6 +483,8 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
self.setup_buckets()
if self.is_caching_latents:
self.cache_latents_all_latents()
if self.is_caching_clip_vision_to_disk:
self.cache_clip_vision_to_disk()
else:
if self.dataset_config.poi is not None:
# handle cropping to a specific point of interest
@@ -611,3 +614,19 @@ def trigger_dataloader_setup_epoch(dataloader: DataLoader):
if hasattr(sub_dataset, 'setup_epoch'):
sub_dataset.setup_epoch()
sub_dataset.len = None
def get_dataloader_datasets(dataloader: DataLoader):
# hacky but needed because of different types of datasets and dataloaders
if isinstance(dataloader.dataset, list):
datasets = []
for dataset in dataloader.dataset:
if hasattr(dataset, 'datasets'):
for sub_dataset in dataset.datasets:
datasets.append(sub_dataset)
else:
datasets.append(dataset)
return datasets
elif hasattr(dataloader.dataset, 'datasets'):
return dataloader.dataset.datasets
else:
return [dataloader.dataset]
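A short usage sketch of this helper, mirroring how hook_before_train_loop above gathers the cached unconditional clip embeds (the data_loader variable is illustrative):
# flatten whatever dataset structure sits behind the DataLoader,
# then collect the unconditional clip vision cache paths from each sub-dataset
datasets = get_dataloader_datasets(data_loader)
unconditional_clip_image_embeds = []
for dataset in datasets:
    unconditional_clip_image_embeds += dataset.clip_vision_unconditional_cache
if len(unconditional_clip_image_embeds) == 0:
    raise ValueError("No unconditional clip image embeds found")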

View File

@@ -96,6 +96,8 @@ class DataLoaderBatchDTO:
self.unaugmented_tensor: Union[torch.Tensor, None] = None
self.unconditional_tensor: Union[torch.Tensor, None] = None
self.unconditional_latents: Union[torch.Tensor, None] = None
self.clip_image_embeds: Union[List[dict], None] = None
self.clip_image_embeds_unconditional: Union[List[dict], None] = None
self.sigmas: Union[torch.Tensor, None] = None # can be added elsewhere and passed along the training code
if not is_latents_cached:
# only return a tensor if latents are not cached
@@ -183,6 +185,23 @@ class DataLoaderBatchDTO:
else:
unconditional_tensor.append(x.unconditional_tensor)
self.unconditional_tensor = torch.cat([x.unsqueeze(0) for x in unconditional_tensor])
if any([x.clip_image_embeds is not None for x in self.file_items]):
self.clip_image_embeds = []
for x in self.file_items:
if x.clip_image_embeds is not None:
self.clip_image_embeds.append(x.clip_image_embeds)
else:
raise Exception("clip_image_embeds is None for some file items")
if any([x.clip_image_embeds_unconditional is not None for x in self.file_items]):
self.clip_image_embeds_unconditional = []
for x in self.file_items:
if x.clip_image_embeds_unconditional is not None:
self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional)
else:
raise Exception("clip_image_embeds_unconditional is None for some file items")
except Exception as e:
print(e)
raise e

View File

@@ -12,7 +12,7 @@ import numpy as np
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import CLIPImageProcessor
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.basic import flush, value_map
from toolkit.buckets import get_bucket_for_image_size, get_resolution
@@ -570,9 +570,18 @@ class ClipImageFileItemDTOMixin:
self.has_clip_image = False
self.clip_image_path: Union[str, None] = None
self.clip_image_tensor: Union[torch.Tensor, None] = None
self.clip_image_embeds: Union[dict, None] = None
self.clip_image_embeds_unconditional: Union[dict, None] = None
self.has_clip_augmentations = False
self.clip_image_aug_transform: Union[None, A.Compose] = None
self.clip_image_processor: Union[None, CLIPImageProcessor] = None
self.clip_image_encoder_path: Union[str, None] = None
self.is_caching_clip_vision_to_disk = False
self.is_vision_clip_cached = False
self.clip_vision_is_quad = False
self.clip_vision_load_device = 'cpu'
self.clip_vision_unconditional_paths: Union[List[str], None] = None
self._clip_vision_embeddings_path: Union[str, None] = None
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
if dataset_config.clip_image_path is not None:
# copy the clip image processor so the dataloader can do it
@@ -633,7 +642,45 @@ class ClipImageFileItemDTOMixin:
return augmented_tensor
def get_clip_vision_info_dict(self: 'FileItemDTO'):
item = OrderedDict([
("image_encoder_path", self.clip_image_encoder_path),
("filename", os.path.basename(self.clip_image_path)),
("is_quad", self.clip_vision_is_quad)
])
# when adding new keys, add them after existing ones so we don't invalidate previously cached files
if self.flip_x:
item["flip_x"] = True
if self.flip_y:
item["flip_y"] = True
return item
def get_clip_vision_embeddings_path(self: 'FileItemDTO', recalculate=False):
if self._clip_vision_embeddings_path is not None and not recalculate:
return self._clip_vision_embeddings_path
else:
# we store clip vision embeddings in a folder next to the image called _clip_vision_cache
img_dir = os.path.dirname(self.clip_image_path)
latent_dir = os.path.join(img_dir, '_clip_vision_cache')
hash_dict = self.get_clip_vision_info_dict()
filename_no_ext = os.path.splitext(os.path.basename(self.clip_image_path))[0]
# get base64 hash of md5 checksum of hash_dict
hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
hash_str = hash_str.replace('=', '')
self._clip_vision_embeddings_path = os.path.join(latent_dir, f'{filename_no_ext}_{hash_str}.safetensors')
return self._clip_vision_embeddings_path
def load_clip_image(self: 'FileItemDTO'):
if self.is_vision_clip_cached:
self.clip_image_embeds = load_file(self.get_clip_vision_embeddings_path())
# get a random unconditional image
if self.clip_vision_unconditional_paths is not None:
unconditional_path = random.choice(self.clip_vision_unconditional_paths)
self.clip_image_embeds_unconditional = load_file(unconditional_path)
return
img = Image.open(self.clip_image_path).convert('RGB')
try:
img = exif_transpose(img)
@@ -683,6 +730,7 @@ class ClipImageFileItemDTOMixin:
def cleanup_clip_image(self: 'FileItemDTO'):
self.clip_image_tensor = None
self.clip_image_embeds = None
@@ -1273,7 +1321,7 @@ class LatentCachingMixin:
del latent
del file_item.tensor
flush(garbage_collect=False)
# flush(garbage_collect=False)
file_item.is_latent_cached = True
i += 1
# flush every 100
@@ -1282,3 +1330,176 @@ class LatentCachingMixin:
# restore device state
self.sd.restore_device_state()
class CLIPCachingMixin:
def __init__(self: 'AiToolkitDataset', **kwargs):
# if we have super, call it
if hasattr(super(), '__init__'):
super().__init__(**kwargs)
self.clip_vision_num_unconditional_cache = 20
self.clip_vision_unconditional_cache = []
def cache_clip_vision_to_disk(self: 'AiToolkitDataset'):
if not self.is_caching_clip_vision_to_disk:
return
with torch.no_grad():
print(f"Caching clip vision for {self.dataset_path}")
print(" - Saving clip to disk")
# move sd items to cpu except for the clip vision encoder
self.sd.set_device_state_preset('cache_clip')
# make sure the adapter has attributes
if self.sd.adapter is None:
raise Exception("Error: must have an adapter to cache clip vision to disk")
clip_image_processor: CLIPImageProcessor = None
if hasattr(self.sd.adapter, 'clip_image_processor'):
clip_image_processor = self.sd.adapter.clip_image_processor
if clip_image_processor is None:
raise Exception("Error: must have a clip image processor to cache clip vision to disk")
vision_encoder: CLIPVisionModelWithProjection = None
if hasattr(self.sd.adapter, 'image_encoder'):
vision_encoder = self.sd.adapter.image_encoder
if hasattr(self.sd.adapter, 'vision_encoder'):
vision_encoder = self.sd.adapter.vision_encoder
if vision_encoder is None:
raise Exception("Error: must have a vision encoder to cache clip vision to disk")
# move vision encoder to device
vision_encoder.to(self.sd.device)
is_quad = self.sd.adapter.config.quad_image
image_encoder_path = self.sd.adapter.config.image_encoder_path
dtype = self.sd.torch_dtype
device = self.sd.device_torch
if hasattr(self.sd.adapter, 'clip_noise_zero') and self.sd.adapter.clip_noise_zero:
# the unconditional input is random noise, so keep multiple cached samples
pass
else:
# only need one since it doesn't change
self.clip_vision_num_unconditional_cache = 1
# cache unconditionals
print(f" - Caching {self.clip_vision_num_unconditional_cache} unconditional clip vision to disk")
clip_vision_cache_path = os.path.join(self.dataset_config.clip_image_path, '_clip_vision_cache')
unconditional_paths = []
is_noise_zero = hasattr(self.sd.adapter, 'clip_noise_zero') and self.sd.adapter.clip_noise_zero
for i in range(self.clip_vision_num_unconditional_cache):
hash_dict = OrderedDict([
("image_encoder_path", image_encoder_path),
("is_quad", is_quad),
("is_noise_zero", is_noise_zero),
])
# get base64 hash of md5 checksum of hash_dict
hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
hash_str = hash_str.replace('=', '')
uncond_path = os.path.join(clip_vision_cache_path, f'uncond_{hash_str}_{i}.safetensors')
if os.path.exists(uncond_path):
# skip it
unconditional_paths.append(uncond_path)
continue
# generate a random image
img_shape = (1, 3, self.sd.adapter.input_size, self.sd.adapter.input_size)
if is_noise_zero:
tensors_0_1 = torch.rand(img_shape).to(device, dtype=torch.float32)
else:
tensors_0_1 = torch.zeros(img_shape).to(device, dtype=torch.float32)
clip_image = clip_image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
if is_quad:
# split the 2x2 grid into 4 images and stack on the batch dimension
ci1, ci2 = clip_image.chunk(2, dim=2)
ci1, ci3 = ci1.chunk(2, dim=3)
ci2, ci4 = ci2.chunk(2, dim=3)
clip_image = torch.cat([ci1, ci2, ci3, ci4], dim=0).detach()
clip_output = vision_encoder(
clip_image.to(device, dtype=dtype),
output_hidden_states=True
)
# make state_dict ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states']
state_dict = OrderedDict([
('image_embeds', clip_output.image_embeds.clone().detach().cpu()),
('last_hidden_state', clip_output.hidden_states[-1].clone().detach().cpu()),
('penultimate_hidden_states', clip_output.hidden_states[-2].clone().detach().cpu()),
])
os.makedirs(os.path.dirname(uncond_path), exist_ok=True)
save_file(state_dict, uncond_path)
unconditional_paths.append(uncond_path)
self.clip_vision_unconditional_cache = unconditional_paths
# use tqdm to show progress
i = 0
for file_item in tqdm(self.file_list, desc=f'Caching clip vision to disk'):
file_item.is_caching_clip_vision_to_disk = True
file_item.clip_vision_load_device = self.sd.device
file_item.clip_vision_is_quad = is_quad
file_item.clip_image_encoder_path = image_encoder_path
file_item.clip_vision_unconditional_paths = unconditional_paths
if file_item.has_clip_augmentations:
raise Exception("Error: clip vision caching is not supported with clip augmentations")
embedding_path = file_item.get_clip_vision_embeddings_path(recalculate=True)
# check if it is saved to disk already
if not os.path.exists(embedding_path):
# load the image first
file_item.load_clip_image()
# add batch dimension
clip_image = file_item.clip_image_tensor.unsqueeze(0).to(device, dtype=dtype)
if is_quad:
# split the 2x2 grid into 4 images and stack on the batch dimension
ci1, ci2 = clip_image.chunk(2, dim=2)
ci1, ci3 = ci1.chunk(2, dim=3)
ci2, ci4 = ci2.chunk(2, dim=3)
clip_image = torch.cat([ci1, ci2, ci3, ci4], dim=0).detach()
clip_output = vision_encoder(
clip_image.to(device, dtype=dtype),
output_hidden_states=True
)
# make state_dict ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states']
state_dict = OrderedDict([
('image_embeds', clip_output.image_embeds.clone().detach().cpu()),
('last_hidden_state', clip_output.hidden_states[-1].clone().detach().cpu()),
('penultimate_hidden_states', clip_output.hidden_states[-2].clone().detach().cpu()),
])
# metadata
meta = get_meta_for_safetensors(file_item.get_clip_vision_info_dict())
os.makedirs(os.path.dirname(embedding_path), exist_ok=True)
save_file(state_dict, embedding_path, metadata=meta)
del clip_image
del clip_output
del file_item.clip_image_tensor
# flush(garbage_collect=False)
file_item.is_vision_clip_cached = True
i += 1
# flush every 100
# if i % 100 == 0:
# flush()
# restore device state
self.sd.restore_device_state()

View File

@@ -249,9 +249,13 @@ class IPAdapter(torch.nn.Module):
preprocessor_input_size = self.image_encoder.config.image_size * 2
# update the preprocessor so images come in at the right size
self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
self.clip_image_processor.crop_size['height'] = preprocessor_input_size
self.clip_image_processor.crop_size['width'] = preprocessor_input_size
if 'height' in self.clip_image_processor.size:
self.clip_image_processor.size['height'] = preprocessor_input_size
self.clip_image_processor.size['width'] = preprocessor_input_size
elif hasattr(self.clip_image_processor, 'crop_size'):
self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
self.clip_image_processor.crop_size['height'] = preprocessor_input_size
self.clip_image_processor.crop_size['width'] = preprocessor_input_size
if self.config.image_encoder_arch == 'clip+':
# self.clip_image_processor.config
@@ -439,6 +443,32 @@ class IPAdapter(torch.nn.Module):
if self.preprocessor is not None:
self.preprocessor.to(*args, **kwargs)
return self
def parse_clip_image_embeds_from_cache(
self,
image_embeds_list: List[dict], # has ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states']
quad_count=4,
):
with torch.no_grad():
device = self.sd_ref().unet.device
if self.config.type.startswith('ip+'):
clip_image_embeds = torch.cat([x['penultimate_hidden_states'] for x in image_embeds_list], dim=0)
else:
clip_image_embeds = torch.cat([x['image_embeds'] for x in image_embeds_list], dim=0)
if self.config.quad_image:
# get the outputs of the quad images
chunks = clip_image_embeds.chunk(quad_count, dim=0)
chunk_sum = torch.zeros_like(chunks[0])
for chunk in chunks:
chunk_sum = chunk_sum + chunk
# get the mean of them
clip_image_embeds = chunk_sum / quad_count
clip_image_embeds = clip_image_embeds.to(device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
return clip_image_embeds
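A usage sketch matching how the trainer feeds cached safetensors files into this method (the file paths and the adapter/quad_count variables are hypothetical):
from safetensors.torch import load_file

# each cached file holds 'image_embeds', 'last_hidden_state' and 'penultimate_hidden_states'
embeds = [load_file(p) for p in ["img_0.safetensors", "img_1.safetensors"]]
conditional_clip_embeds = adapter.parse_clip_image_embeds_from_cache(
    embeds,
    quad_count=quad_count,
)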
def get_clip_image_embeds_from_tensors(
self,
tensors_0_1: torch.Tensor,

View File

@@ -172,6 +172,8 @@ class CLIPFusionModule(nn.Module):
dim=self.text_hidden_size,
)
self.alpha = nn.Parameter(torch.zeros([text_tokens]) + 0.01)
def forward(self, text_embeds, vision_embeds):
# text_embeds = (batch_size, 77, 768)
# vision_embeds = (batch_size, 257, 1024)
@@ -186,7 +188,12 @@ class CLIPFusionModule(nn.Module):
x = x + res
# alpha mask
alpha = self.ctx_alpha(text_embeds)
x = alpha * x + (1 - alpha) * text_embeds
ctx_alpha = self.ctx_alpha(text_embeds)
# reshape alpha to (1, 77, 1)
alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
x = ctx_alpha * x * alpha
x = x + text_embeds
return x

View File

@@ -841,6 +841,14 @@ class StableDiffusion:
else:
timestep = timestep.repeat(latents.shape[0], 0)
# handle t2i adapters
if 'down_intrablock_additional_residuals' in kwargs:
# go through each item and concat if doing cfg and it doesn't have the same shape
for idx, item in enumerate(kwargs['down_intrablock_additional_residuals']):
if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([item] * 2, dim=0)
def scale_model_input(model_input, timestep_tensor):
if is_input_scaled:
return model_input
@@ -1599,6 +1607,8 @@ class StableDiffusion:
training_modules = []
if device_state_preset in ['cache_latents']:
active_modules = ['vae']
if device_state_preset in ['cache_clip']:
active_modules = ['clip']
if device_state_preset in ['generate']:
active_modules = ['vae', 'unet', 'text_encoder', 'adapter', 'refiner_unet']