diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index b9e5ed69..ffe8aff6 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -6,10 +6,14 @@ import numpy as np
 from diffusers import T2IAdapter, AutoencoderTiny
 import torch.functional as F
+from safetensors.torch import load_file
+from torch.utils.data import DataLoader, ConcatDataset
+
 from toolkit import train_tools
 from toolkit.basic import value_map, adain, get_mean_std
 from toolkit.clip_vision_adapter import ClipVisionAdapter
 from toolkit.config_modules import GuidanceConfig
+from toolkit.data_loader import get_dataloader_datasets
 from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
 from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
 from toolkit.image_utils import show_tensors, show_latents
@@ -46,6 +50,8 @@ class SDTrainer(BaseSDTrainProcess):
         self.do_guided_loss = False
         self.taesd: Optional[AutoencoderTiny] = None
 
+        self._clip_image_embeds_unconditional: Union[List[str], None] = None
+
     def before_model_load(self):
         pass
 
@@ -86,6 +92,22 @@ class SDTrainer(BaseSDTrainProcess):
         if self.adapter is not None:
             self.adapter.to(self.device_torch)
 
+            # check if we have regs and are using an adapter while caching clip embeddings
+            has_reg = self.datasets_reg is not None and len(self.datasets_reg) > 0
+            is_caching_clip_embeddings = any([self.datasets[i].cache_clip_vision_to_disk for i in range(len(self.datasets))])
+
+            if has_reg and is_caching_clip_embeddings:
+                # we need a list of unconditional clip image embeds from other datasets to handle regs
+                unconditional_clip_image_embeds = []
+                datasets = get_dataloader_datasets(self.data_loader)
+                for i in range(len(datasets)):
+                    unconditional_clip_image_embeds += datasets[i].clip_vision_unconditional_cache
+
+                if len(unconditional_clip_image_embeds) == 0:
+                    raise ValueError("No unconditional clip image embeds found. This should not happen")
This should not happen") + + self._clip_image_embeds_unconditional = unconditional_clip_image_embeds + def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch): # to process turbo learning, we make one big step from our current timestep to the end # we then denoise the prediction on that remaining step and target our loss to our target latents @@ -190,6 +212,7 @@ class SDTrainer(BaseSDTrainProcess): **kwargs ): loss_target = self.train_config.loss_target + is_reg = any(batch.get_is_reg_list()) prior_mask_multiplier = None target_mask_multiplier = None @@ -202,6 +225,40 @@ class SDTrainer(BaseSDTrainProcess): noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True) noise_pred = noise_pred * (noise_norm / noise_pred_norm) + if self.train_config.correct_pred_norm and not is_reg: + with torch.no_grad(): + + # adjust the noise target in the opposite direction of the noise pred mean and std offset + # this will apply additional force the model to correct itself to match the norm of the noise + noise_pred_mean, noise_pred_std = get_mean_std(noise_pred) + noise_mean, noise_std = get_mean_std(noise) + + # apply the inverse offset of the mean and std to the noise + noise_additional_mean = noise_mean - noise_pred_mean + noise_additional_std = noise_std - noise_pred_std + + # adjust for multiplier + noise_additional_mean = noise_additional_mean * self.train_config.correct_pred_norm_multiplier + noise_additional_std = noise_additional_std * self.train_config.correct_pred_norm_multiplier + + noise_target_std = noise_std + noise_additional_std + noise_target_mean = noise_mean + noise_additional_mean + + + noise_pred_target_std = noise_pred_std - noise_additional_std + noise_pred_target_mean = noise_pred_mean - noise_additional_mean + noise_pred_target_std = noise_pred_target_std.detach() + noise_pred_target_mean = noise_pred_target_mean.detach() + + # match the noise to the target + noise = (noise - noise_mean) / noise_std + noise = noise * noise_target_std + noise_target_mean + noise = noise.detach() + + # meatch the noise pred to the target + # noise_pred = (noise_pred - noise_pred_mean) / noise_pred_std + # noise_pred = noise_pred * noise_pred_target_std + noise_pred_target_mean + if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask: assert not self.train_config.train_turbo # we need to make the noise prediction be a masked blending of noise and prior_pred @@ -227,7 +284,7 @@ class SDTrainer(BaseSDTrainProcess): target = prior_pred elif self.sd.prediction_type == 'v_prediction': # v-parameterization training - target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps) + target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps) else: target = noise @@ -270,7 +327,10 @@ class SDTrainer(BaseSDTrainProcess): loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2) loss = loss_per_element else: - loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none") + if self.train_config.loss_type == "mae": + loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none") + else: + loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none") # multiply by our mask loss = loss * mask_multiplier @@ -278,12 +338,11 @@ class SDTrainer(BaseSDTrainProcess): prior_loss = None if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None: assert not 
-            # to a loss to unmasked areas of the prior for unmasked regularization
-            prior_loss = torch.nn.functional.mse_loss(
-                prior_pred.float(),
-                pred.float(),
-                reduction="none"
-            )
+            if self.train_config.loss_type == "mae":
+                prior_loss = torch.nn.functional.l1_loss(pred.float(), prior_pred.float(), reduction="none")
+            else:
+                prior_loss = torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none")
+
             prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
             if torch.isnan(prior_loss).any():
                 print("Prior loss is nan")
@@ -717,6 +776,13 @@ class SDTrainer(BaseSDTrainProcess):
 
         has_adapter_img = batch.control_tensor is not None
         has_clip_image = batch.clip_image_tensor is not None
+        has_clip_image_embeds = batch.clip_image_embeds is not None
+        # force it to be true if doing regs as we handle those differently
+        if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]):
+            has_clip_image = True
+            if self._clip_image_embeds_unconditional is not None:
+                has_clip_image_embeds = True  # we are caching embeds, handle that differently
+                has_clip_image = False
 
         if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
             raise ValueError(
@@ -996,7 +1062,39 @@ class SDTrainer(BaseSDTrainProcess):
                     # number of images to do if doing a quad image
                     quad_count = random.randint(1, 4)
                     image_size = self.adapter.input_size
-                    if is_reg:
+                    if has_clip_image_embeds:
+                        # todo handle reg images better than this
+                        if is_reg:
+                            # get unconditional image embeds from cache
+                            embeds = [
+                                load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
+                                range(noisy_latents.shape[0])
+                            ]
+                            conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
+                                embeds,
+                                quad_count=quad_count
+                            )
+
+                            if self.train_config.do_cfg:
+                                embeds = [
+                                    load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in range(noisy_latents.shape[0])
+                                ]
+                                unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
+                                    embeds,
+                                    quad_count=quad_count
+                                )
+
+                        else:
+                            conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
+                                batch.clip_image_embeds,
+                                quad_count=quad_count
+                            )
+                            if self.train_config.do_cfg:
+                                unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
+                                    batch.clip_image_embeds_unconditional,
+                                    quad_count=quad_count
+                                )
+                    elif is_reg:
                         # we will zero it out in the img embedder
                         clip_images = torch.zeros(
                             (noisy_latents.shape[0], 3, image_size, image_size),
@@ -1071,9 +1169,9 @@ class SDTrainer(BaseSDTrainProcess):
         prior_pred = None
 
         do_reg_prior = False
-        # if is_reg and (self.network is not None or self.adapter is not None):
-        #     # we are doing a reg image and we have a network or adapter
-        #     do_reg_prior = True
+        if is_reg and (self.network is not None or self.adapter is not None):
+            # we are doing a reg image and we have a network or adapter
+            do_reg_prior = True
 
         do_inverted_masked_prior = False
         if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
@@ -1096,12 +1194,14 @@ class SDTrainer(BaseSDTrainProcess):
 
         # do the custom adapter after the prior prediction
         if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
+            quad_count = random.randint(1, 4)
             self.adapter.train()
             conditional_embeds = self.adapter.condition_encoded_embeds(
                 tensors_0_1=clip_images,
                 prompt_embeds=conditional_embeds,
                 is_training=True,
                 has_been_preprocessed=True,
+                quad_count=quad_count
             )
             if self.train_config.do_cfg and unconditional_embeds is not None:
                 unconditional_embeds = self.adapter.condition_encoded_embeds(
@@ -1109,7 +1209,8 @@ class SDTrainer(BaseSDTrainProcess):
                     prompt_embeds=unconditional_embeds,
                     is_training=True,
                     has_been_preprocessed=True,
-                    is_unconditional=True
+                    is_unconditional=True,
+                    quad_count=quad_count
                 )
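
The correct_pred_norm hunk above is easier to follow in isolation. Below is a minimal, illustrative sketch (not part of the patch) of the target correction, assuming get_mean_std returns per-sample mean/std over the channel and spatial dimensions; the helper defined here is a stand-in for toolkit.basic.get_mean_std.

# illustrative sketch only -- not part of the patch
import torch


def get_mean_std(t: torch.Tensor):
    # stand-in: per-sample mean/std over C, H, W, kept broadcastable
    return t.mean(dim=(1, 2, 3), keepdim=True), t.std(dim=(1, 2, 3), keepdim=True)


def correct_noise_target(noise: torch.Tensor, noise_pred: torch.Tensor, multiplier: float = 1.0) -> torch.Tensor:
    # shift the target's mean/std opposite to the prediction's drift, as in the hunk above
    with torch.no_grad():
        pred_mean, pred_std = get_mean_std(noise_pred)
        noise_mean, noise_std = get_mean_std(noise)
        target_mean = noise_mean + (noise_mean - pred_mean) * multiplier
        target_std = noise_std + (noise_std - pred_std) * multiplier
        corrected = (noise - noise_mean) / noise_std        # normalize
        corrected = corrected * target_std + target_mean    # re-scale/shift to the corrected stats
    return corrected.detach()


noise = torch.randn(2, 4, 64, 64)
noise_pred = torch.randn(2, 4, 64, 64) * 1.1 + 0.05  # pretend the model's norm drifts
target = correct_noise_target(noise, noise_pred, multiplier=1.0)
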
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index de9f6eba..b8b788df 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -285,6 +285,13 @@ class TrainConfig:
         self.cfg_scale = kwargs.get('cfg_scale', 1.0)
         self.max_cfg_scale = kwargs.get('max_cfg_scale', self.cfg_scale)
 
+        # applies the inverse of the prediction mean and std to the target to correct
+        # for norm drift
+        self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
+        self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
+
+        self.loss_type = kwargs.get('loss_type', 'mse')
+
 
 class ModelConfig:
     def __init__(self, **kwargs):
@@ -444,6 +451,7 @@ class DatasetConfig:
         self.cache_latents: bool = kwargs.get('cache_latents', False)
         # cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
         self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
+        self.cache_clip_vision_to_disk: bool = kwargs.get('cache_clip_vision_to_disk', False)
 
         self.standardize_images: bool = kwargs.get('standardize_images', False)
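
A minimal sketch of how the new knobs might be passed in, assuming the usual kwargs-driven construction of these config classes; the paths and any keys not shown in the diff are hypothetical placeholders, not part of the patch.

# illustrative sketch only -- not part of the patch
from toolkit.config_modules import TrainConfig, DatasetConfig

train_config = TrainConfig(
    loss_type='mae',                   # new: 'mse' (default) or 'mae'
    correct_pred_norm=True,            # new: counteract prediction norm drift
    correct_pred_norm_multiplier=1.0,  # new: strength of the correction
)

dataset_config = DatasetConfig(
    folder_path='/path/to/images',      # hypothetical
    clip_image_path='/path/to/images',  # hypothetical
    cache_clip_vision_to_disk=True,     # new: pre-encode CLIP vision embeds once and reuse them
)
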
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 810f7d67..3e02f497 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -227,6 +227,16 @@ class CustomAdapter(torch.nn.Module):
 
             self.input_size = self.vision_encoder.config.image_size
 
+            if self.config.quad_image:  # 2x2 grid of input images
+                # self.clip_image_processor.config
+                # double the preprocessor input size so each quadrant lands at the encoder's native size
+                preprocessor_input_size = self.vision_encoder.config.image_size * 2
+
+                # update the preprocessor so images come in at the right size
+                self.image_processor.size['shortest_edge'] = preprocessor_input_size
+                self.image_processor.crop_size['height'] = preprocessor_input_size
+                self.image_processor.crop_size['width'] = preprocessor_input_size
+
             if self.config.image_encoder_arch == 'clip+':
                 # self.image_processor.config
                 # We do a 3x downscale of the image, so we need to adjust the input size
@@ -425,7 +435,8 @@ class CustomAdapter(torch.nn.Module):
             prompt_embeds: PromptEmbeds,
             is_training=False,
             has_been_preprocessed=False,
-            is_unconditional=False
+            is_unconditional=False,
+            quad_count=4,
     ) -> PromptEmbeds:
         if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
             if is_unconditional:
@@ -454,6 +465,20 @@ class CustomAdapter(torch.nn.Module):
                 clip_image = tensors_0_1
                 clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
 
+            if self.config.quad_image:
+                # split the 2x2 grid and stack on batch
+                ci1, ci2 = clip_image.chunk(2, dim=2)
+                ci1, ci3 = ci1.chunk(2, dim=3)
+                ci2, ci4 = ci2.chunk(2, dim=3)
+                to_cat = []
+                for i, ci in enumerate([ci1, ci2, ci3, ci4]):
+                    if i < quad_count:
+                        to_cat.append(ci)
+                    else:
+                        break
+
+                clip_image = torch.cat(to_cat, dim=0).detach()
+
             if self.adapter_type == 'photo_maker':
                 # Embeddings need to be (b, num_inputs, c, h, w) for now, just put 1 input image
                 clip_image = clip_image.unsqueeze(1)
@@ -496,6 +521,17 @@ class CustomAdapter(torch.nn.Module):
 
                 img_embeds = id_embeds['last_hidden_state']
 
+                if self.config.quad_image:
+                    # get the outputs of the quad
+                    chunks = img_embeds.chunk(quad_count, dim=0)
+                    chunk_sum = torch.zeros_like(chunks[0])
+                    for chunk in chunks:
+                        chunk_sum = chunk_sum + chunk
+                    # get the mean of them
+                    img_embeds = chunk_sum / quad_count
+
                 if not is_training or not self.config.train_image_encoder:
                     img_embeds = img_embeds.detach()
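
The quad_image handling above tiles a 2x2 grid into the batch dimension before the vision encoder and then averages the per-tile embeddings. A standalone sketch of both halves (same chunk order as the diff), using random tensors in place of a real encoder:

# illustrative sketch only -- not part of the patch
import torch


def split_quad_to_batch(grid: torch.Tensor, quad_count: int = 4) -> torch.Tensor:
    # (B, C, 2H, 2W) grid -> up to (B * quad_count, C, H, W)
    top, bottom = grid.chunk(2, dim=2)
    tl, tr = top.chunk(2, dim=3)
    bl, br = bottom.chunk(2, dim=3)
    return torch.cat([tl, bl, tr, br][:quad_count], dim=0)


def merge_quad_embeds(embeds: torch.Tensor, quad_count: int = 4) -> torch.Tensor:
    # average the per-tile embeddings back into one embedding per grid
    return torch.stack(embeds.chunk(quad_count, dim=0), dim=0).mean(dim=0)


grid = torch.randn(1, 3, 448, 448)       # 2x2 grid of 224px tiles
tiles = split_quad_to_batch(grid)        # (4, 3, 224, 224)
tile_embeds = torch.randn(4, 257, 1024)  # pretend encoder output per tile
merged = merge_quad_embeds(tile_embeds)  # (1, 257, 1024)
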
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index aa518788..df89544e 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -18,7 +18,7 @@ import albumentations as A
 from toolkit.buckets import get_bucket_for_image_size, BucketResolution
 from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
-from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments
+from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin
 from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
 
 if TYPE_CHECKING:
@@ -355,7 +355,7 @@ class PairedImageDataset(Dataset):
         return img, prompt, (self.neg_weight, self.pos_weight)
 
 
-class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
+class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, CaptionMixin, Dataset):
 
     def __init__(
             self,
@@ -373,6 +373,7 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
         self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
         self.is_caching_latents_to_memory = dataset_config.cache_latents
         self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
+        self.is_caching_clip_vision_to_disk = dataset_config.cache_clip_vision_to_disk
         self.epoch_num = 0
 
         self.sd = sd
@@ -482,6 +483,8 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
             self.setup_buckets()
             if self.is_caching_latents:
                 self.cache_latents_all_latents()
+            if self.is_caching_clip_vision_to_disk:
+                self.cache_clip_vision_to_disk()
         else:
             if self.dataset_config.poi is not None:
                 # handle cropping to a specific point of interest
@@ -611,3 +614,19 @@ def trigger_dataloader_setup_epoch(dataloader: DataLoader):
         if hasattr(sub_dataset, 'setup_epoch'):
             sub_dataset.setup_epoch()
             sub_dataset.len = None
+
+def get_dataloader_datasets(dataloader: DataLoader):
+    # hacky but needed because of different types of datasets and dataloaders
+    if isinstance(dataloader.dataset, list):
+        datasets = []
+        for dataset in dataloader.dataset:
+            if hasattr(dataset, 'datasets'):
+                for sub_dataset in dataset.datasets:
+                    datasets.append(sub_dataset)
+            else:
+                datasets.append(dataset)
+        return datasets
+    elif hasattr(dataloader.dataset, 'datasets'):
+        return dataloader.dataset.datasets
+    else:
+        return [dataloader.dataset]
diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py
index b06feb39..ba0a6e4b 100644
--- a/toolkit/data_transfer_object/data_loader.py
+++ b/toolkit/data_transfer_object/data_loader.py
@@ -96,6 +96,8 @@ class DataLoaderBatchDTO:
             self.unaugmented_tensor: Union[torch.Tensor, None] = None
             self.unconditional_tensor: Union[torch.Tensor, None] = None
             self.unconditional_latents: Union[torch.Tensor, None] = None
+            self.clip_image_embeds: Union[List[dict], None] = None
+            self.clip_image_embeds_unconditional: Union[List[dict], None] = None
             self.sigmas: Union[torch.Tensor, None] = None  # can be added elseware and passed along training code
             if not is_latents_cached:
                 # only return a tensor if latents are not cached
@@ -183,6 +185,23 @@ class DataLoaderBatchDTO:
                     else:
                         unconditional_tensor.append(x.unconditional_tensor)
                 self.unconditional_tensor = torch.cat([x.unsqueeze(0) for x in unconditional_tensor])
+
+            if any([x.clip_image_embeds is not None for x in self.file_items]):
+                self.clip_image_embeds = []
+                for x in self.file_items:
+                    if x.clip_image_embeds is not None:
+                        self.clip_image_embeds.append(x.clip_image_embeds)
+                    else:
+                        raise Exception("clip_image_embeds is None for some file items")
+
+            if any([x.clip_image_embeds_unconditional is not None for x in self.file_items]):
+                self.clip_image_embeds_unconditional = []
+                for x in self.file_items:
+                    if x.clip_image_embeds_unconditional is not None:
+                        self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional)
+                    else:
+                        raise Exception("clip_image_embeds_unconditional is None for some file items")
+
         except Exception as e:
             print(e)
             raise e
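
For reference, each cached item the batch collates above is the dict returned by safetensors' load_file, with the key names saved by CLIPCachingMixin later in this patch; the path below is a hypothetical example.

# illustrative sketch only -- not part of the patch
from safetensors.torch import load_file

embeds = load_file('/data/images/_clip_vision_cache/cat_AbC123.safetensors')  # hypothetical path
# keys: 'image_embeds', 'last_hidden_state', 'penultimate_hidden_states'
clip_image_embeds = [embeds]  # DataLoaderBatchDTO keeps one dict per file item
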
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index ffd13156..302cd9ea 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -12,7 +12,7 @@ import numpy as np
 import torch
 from safetensors.torch import load_file, save_file
 from tqdm import tqdm
-from transformers import CLIPImageProcessor
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 
 from toolkit.basic import flush, value_map
 from toolkit.buckets import get_bucket_for_image_size, get_resolution
@@ -570,9 +570,18 @@ class ClipImageFileItemDTOMixin:
         self.has_clip_image = False
         self.clip_image_path: Union[str, None] = None
         self.clip_image_tensor: Union[torch.Tensor, None] = None
+        self.clip_image_embeds: Union[dict, None] = None
+        self.clip_image_embeds_unconditional: Union[dict, None] = None
         self.has_clip_augmentations = False
         self.clip_image_aug_transform: Union[None, A.Compose] = None
         self.clip_image_processor: Union[None, CLIPImageProcessor] = None
+        self.clip_image_encoder_path: Union[str, None] = None
+        self.is_caching_clip_vision_to_disk = False
+        self.is_vision_clip_cached = False
+        self.clip_vision_is_quad = False
+        self.clip_vision_load_device = 'cpu'
+        self.clip_vision_unconditional_paths: Union[List[str], None] = None
+        self._clip_vision_embeddings_path: Union[str, None] = None
         dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
         if dataset_config.clip_image_path is not None:
             # copy the clip image processor so the dataloader can do it
@@ -633,7 +642,45 @@ class ClipImageFileItemDTOMixin:
 
         return augmented_tensor
 
+    def get_clip_vision_info_dict(self: 'FileItemDTO'):
+        item = OrderedDict([
+            ("image_encoder_path", self.clip_image_encoder_path),
+            ("filename", os.path.basename(self.clip_image_path)),
+            ("is_quad", self.clip_vision_is_quad)
+        ])
+        # when adding items, do it after so we don't invalidate existing caches
+        if self.flip_x:
+            item["flip_x"] = True
+        if self.flip_y:
+            item["flip_y"] = True
+        return item
+
+    def get_clip_vision_embeddings_path(self: 'FileItemDTO', recalculate=False):
+        if self._clip_vision_embeddings_path is not None and not recalculate:
+            return self._clip_vision_embeddings_path
+        else:
+            # we store the embeddings in a folder in the same path as the image called _clip_vision_cache
+            img_dir = os.path.dirname(self.clip_image_path)
+            latent_dir = os.path.join(img_dir, '_clip_vision_cache')
+            hash_dict = self.get_clip_vision_info_dict()
+            filename_no_ext = os.path.splitext(os.path.basename(self.clip_image_path))[0]
+            # get base64 hash of md5 checksum of hash_dict
+            hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
+            hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
+            hash_str = hash_str.replace('=', '')
+            self._clip_vision_embeddings_path = os.path.join(latent_dir, f'{filename_no_ext}_{hash_str}.safetensors')
+
+        return self._clip_vision_embeddings_path
+
     def load_clip_image(self: 'FileItemDTO'):
+        if self.is_vision_clip_cached:
+            self.clip_image_embeds = load_file(self.get_clip_vision_embeddings_path())
+
+            # get a random unconditional image
+            if self.clip_vision_unconditional_paths is not None:
+                unconditional_path = random.choice(self.clip_vision_unconditional_paths)
+                self.clip_image_embeds_unconditional = load_file(unconditional_path)
+
+            return
         img = Image.open(self.clip_image_path).convert('RGB')
         try:
             img = exif_transpose(img)
@@ -683,6 +730,7 @@ class ClipImageFileItemDTOMixin:
 
     def cleanup_clip_image(self: 'FileItemDTO'):
         self.clip_image_tensor = None
+        self.clip_image_embeds = None
@@ -1273,7 +1321,7 @@ class LatentCachingMixin:
                 del latent
                 del file_item.tensor
 
-            flush(garbage_collect=False)
+            # flush(garbage_collect=False)
             file_item.is_latent_cached = True
             i += 1
             # flush every 100
@@ -1282,3 +1330,176 @@ class LatentCachingMixin:
 
         # restore device state
         self.sd.restore_device_state()
+
+
+class CLIPCachingMixin:
+    def __init__(self: 'AiToolkitDataset', **kwargs):
+        # if we have super, call it
+        if hasattr(super(), '__init__'):
+            super().__init__(**kwargs)
+        self.clip_vision_num_unconditional_cache = 20
+        self.clip_vision_unconditional_cache = []
+
+    def cache_clip_vision_to_disk(self: 'AiToolkitDataset'):
+        if not self.is_caching_clip_vision_to_disk:
+            return
+        with torch.no_grad():
+            print(f"Caching clip vision for {self.dataset_path}")
+
+            print(" - Saving clip to disk")
+            # move sd items to cpu except for the clip vision encoder
+            self.sd.set_device_state_preset('cache_clip')
+
+            # make sure the adapter has attributes
+            if self.sd.adapter is None:
+                raise Exception("Error: must have an adapter to cache clip vision to disk")
+
+            clip_image_processor: CLIPImageProcessor = None
+            if hasattr(self.sd.adapter, 'clip_image_processor'):
+                clip_image_processor = self.sd.adapter.clip_image_processor
+
+            if clip_image_processor is None:
+                raise Exception("Error: must have a clip image processor to cache clip vision to disk")
+
+            vision_encoder: CLIPVisionModelWithProjection = None
+            if hasattr(self.sd.adapter, 'image_encoder'):
+                vision_encoder = self.sd.adapter.image_encoder
+            if hasattr(self.sd.adapter, 'vision_encoder'):
+                vision_encoder = self.sd.adapter.vision_encoder
+
+            if vision_encoder is None:
+                raise Exception("Error: must have a vision encoder to cache clip vision to disk")
+
+            # move vision encoder to device
+            vision_encoder.to(self.sd.device)
+
+            is_quad = self.sd.adapter.config.quad_image
+            image_encoder_path = self.sd.adapter.config.image_encoder_path
+
+            dtype = self.sd.torch_dtype
+            device = self.sd.device_torch
+            if hasattr(self.sd.adapter, 'clip_noise_zero') and self.sd.adapter.clip_noise_zero:
+                # need more samples as it is random noise
+                self.clip_vision_num_unconditional_cache = self.clip_vision_num_unconditional_cache
+            else:
+                # only need one since it doesn't change
+                self.clip_vision_num_unconditional_cache = 1
+
+            # cache unconditionals
+            print(f" - Caching {self.clip_vision_num_unconditional_cache} unconditional clip vision to disk")
+            clip_vision_cache_path = os.path.join(self.dataset_config.clip_image_path, '_clip_vision_cache')
+
+            unconditional_paths = []
+
+            is_noise_zero = hasattr(self.sd.adapter, 'clip_noise_zero') and self.sd.adapter.clip_noise_zero
+
+            for i in range(self.clip_vision_num_unconditional_cache):
+                hash_dict = OrderedDict([
+                    ("image_encoder_path", image_encoder_path),
+                    ("is_quad", is_quad),
+                    ("is_noise_zero", is_noise_zero),
+                ])
+                # get base64 hash of md5 checksum of hash_dict
+                hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
+                hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
+                hash_str = hash_str.replace('=', '')
+
+                uncond_path = os.path.join(clip_vision_cache_path, f'uncond_{hash_str}_{i}.safetensors')
+                if os.path.exists(uncond_path):
+                    # skip it
+                    unconditional_paths.append(uncond_path)
+                    continue
+
+                # generate a random noise image (or zeros when not using noise)
+                img_shape = (1, 3, self.sd.adapter.input_size, self.sd.adapter.input_size)
+                if is_noise_zero:
+                    tensors_0_1 = torch.rand(img_shape).to(device, dtype=torch.float32)
+                else:
+                    tensors_0_1 = torch.zeros(img_shape).to(device, dtype=torch.float32)
+                clip_image = clip_image_processor(
+                    images=tensors_0_1,
+                    return_tensors="pt",
+                    do_resize=True,
+                    do_rescale=False,
+                ).pixel_values
+
+                if is_quad:
+                    # split the 2x2 grid and stack on batch
+                    ci1, ci2 = clip_image.chunk(2, dim=2)
+                    ci1, ci3 = ci1.chunk(2, dim=3)
+                    ci2, ci4 = ci2.chunk(2, dim=3)
+                    clip_image = torch.cat([ci1, ci2, ci3, ci4], dim=0).detach()
+
+                clip_output = vision_encoder(
+                    clip_image.to(device, dtype=dtype),
+                    output_hidden_states=True
+                )
+                # make state_dict ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states']
+                state_dict = OrderedDict([
+                    ('image_embeds', clip_output.image_embeds.clone().detach().cpu()),
+                    ('last_hidden_state', clip_output.hidden_states[-1].clone().detach().cpu()),
+                    ('penultimate_hidden_states', clip_output.hidden_states[-2].clone().detach().cpu()),
+                ])
+
+                os.makedirs(os.path.dirname(uncond_path), exist_ok=True)
+                save_file(state_dict, uncond_path)
+                unconditional_paths.append(uncond_path)
+
+            self.clip_vision_unconditional_cache = unconditional_paths
+
+            # use tqdm to show progress
+            i = 0
+            for file_item in tqdm(self.file_list, desc=f'Caching clip vision to disk'):
+                file_item.is_caching_clip_vision_to_disk = True
+                file_item.clip_vision_load_device = self.sd.device
+                file_item.clip_vision_is_quad = is_quad
+                file_item.clip_image_encoder_path = image_encoder_path
+                file_item.clip_vision_unconditional_paths = unconditional_paths
+                if file_item.has_clip_augmentations:
+                    raise Exception("Error: clip vision caching is not supported with clip augmentations")
+
+                embedding_path = file_item.get_clip_vision_embeddings_path(recalculate=True)
+                # check if it is saved to disk already
+                if not os.path.exists(embedding_path):
+                    # load the image first
+                    file_item.load_clip_image()
+                    # add batch dimension
+                    clip_image = file_item.clip_image_tensor.unsqueeze(0).to(device, dtype=dtype)
+
+                    if is_quad:
+                        # split the 2x2 grid and stack on batch
+                        ci1, ci2 = clip_image.chunk(2, dim=2)
+                        ci1, ci3 = ci1.chunk(2, dim=3)
+                        ci2, ci4 = ci2.chunk(2, dim=3)
+                        clip_image = torch.cat([ci1, ci2, ci3, ci4], dim=0).detach()
+
+                    clip_output = vision_encoder(
+                        clip_image.to(device, dtype=dtype),
+                        output_hidden_states=True
+                    )
+
+                    # make state_dict ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states']
+                    state_dict = OrderedDict([
+                        ('image_embeds', clip_output.image_embeds.clone().detach().cpu()),
+                        ('last_hidden_state', clip_output.hidden_states[-1].clone().detach().cpu()),
+                        ('penultimate_hidden_states', clip_output.hidden_states[-2].clone().detach().cpu()),
+                    ])
+                    # metadata
+                    meta = get_meta_for_safetensors(file_item.get_clip_vision_info_dict())
+                    os.makedirs(os.path.dirname(embedding_path), exist_ok=True)
+                    save_file(state_dict, embedding_path, metadata=meta)
+
+                    del clip_image
+                    del clip_output
+                    del file_item.clip_image_tensor
+
+                # flush(garbage_collect=False)
+                file_item.is_vision_clip_cached = True
+                i += 1
+                # flush every 100
+                # if i % 100 == 0:
+                #     flush()
+
+            # restore device state
+            self.sd.restore_device_state()
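
The cache file naming above hashes an OrderedDict of everything that affects the embedding (encoder path, filename, quad mode, and flips when set). A standalone sketch of that scheme, with the flips omitted for brevity:

# illustrative sketch only -- not part of the patch
import base64
import hashlib
import json
import os
from collections import OrderedDict


def clip_vision_cache_path(image_path: str, image_encoder_path: str, is_quad: bool) -> str:
    info = OrderedDict([
        ("image_encoder_path", image_encoder_path),
        ("filename", os.path.basename(image_path)),
        ("is_quad", is_quad),
    ])
    hash_input = json.dumps(info, sort_keys=True).encode('utf-8')
    hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii').replace('=', '')
    name = f"{os.path.splitext(os.path.basename(image_path))[0]}_{hash_str}.safetensors"
    return os.path.join(os.path.dirname(image_path), '_clip_vision_cache', name)


print(clip_vision_cache_path('/data/images/cat.png', 'openai/clip-vit-large-patch14', False))
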
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index fce13e90..6ad675b4 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -249,9 +249,13 @@ class IPAdapter(torch.nn.Module):
             preprocessor_input_size = self.image_encoder.config.image_size * 2
 
             # update the preprocessor so images come in at the right size
-            self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
-            self.clip_image_processor.crop_size['height'] = preprocessor_input_size
-            self.clip_image_processor.crop_size['width'] = preprocessor_input_size
+            if 'height' in self.clip_image_processor.size:
+                self.clip_image_processor.size['height'] = preprocessor_input_size
+                self.clip_image_processor.size['width'] = preprocessor_input_size
+            elif hasattr(self.clip_image_processor, 'crop_size'):
+                self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
+                self.clip_image_processor.crop_size['height'] = preprocessor_input_size
+                self.clip_image_processor.crop_size['width'] = preprocessor_input_size
 
         if self.config.image_encoder_arch == 'clip+':
             # self.clip_image_processor.config
@@ -439,6 +443,32 @@ class IPAdapter(torch.nn.Module):
         if self.preprocessor is not None:
             self.preprocessor.to(*args, **kwargs)
         return self
+
+    def parse_clip_image_embeds_from_cache(
+            self,
+            image_embeds_list: List[dict],  # has ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states']
+            quad_count=4,
+    ):
+        with torch.no_grad():
+            device = self.sd_ref().unet.device
+            if self.config.type.startswith('ip+'):
+                clip_image_embeds = torch.cat([x['penultimate_hidden_states'] for x in image_embeds_list], dim=0)
+            else:
+                clip_image_embeds = torch.cat([x['image_embeds'] for x in image_embeds_list], dim=0)
+
+            if self.config.quad_image:
+                # get the outputs of the quad
+                chunks = clip_image_embeds.chunk(quad_count, dim=0)
+                chunk_sum = torch.zeros_like(chunks[0])
+                for chunk in chunks:
+                    chunk_sum = chunk_sum + chunk
+                # get the mean of them
+                clip_image_embeds = chunk_sum / quad_count
+
+            clip_image_embeds = clip_image_embeds.to(device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
+        return clip_image_embeds
+
     def get_clip_image_embeds_from_tensors(
             self,
             tensors_0_1: torch.Tensor,
diff --git a/toolkit/models/clip_fusion.py b/toolkit/models/clip_fusion.py
index 73fb3ffb..d2674175 100644
--- a/toolkit/models/clip_fusion.py
+++ b/toolkit/models/clip_fusion.py
@@ -172,6 +172,8 @@ class CLIPFusionModule(nn.Module):
             dim=self.text_hidden_size,
         )
 
+        self.alpha = nn.Parameter(torch.zeros([text_tokens]) + 0.01)
+
     def forward(self, text_embeds, vision_embeds):
         # text_embeds = (batch_size, 77, 768)
         # vision_embeds = (batch_size, 257, 1024)
@@ -186,7 +188,12 @@ class CLIPFusionModule(nn.Module):
             x = x + res
 
         # alpha mask
-        alpha = self.ctx_alpha(text_embeds)
-        x = alpha * x + (1 - alpha) * text_embeds
+        ctx_alpha = self.ctx_alpha(text_embeds)
+        # reshape alpha to (1, 77, 1)
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
+
+        x = ctx_alpha * x * alpha
+
+        x = x + text_embeds
 
         return x
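
The clip_fusion change above replaces the blend x = a * x + (1 - a) * text_embeds with a residual form gated by a learned per-token scalar initialized at 0.01: x = ctx_alpha(text_embeds) * x * alpha + text_embeds. A minimal sketch with a stand-in for ctx_alpha (the real module's definition is outside this diff):

# illustrative sketch only -- not part of the patch
import torch
import torch.nn as nn


class GatedResidualFusion(nn.Module):
    def __init__(self, text_tokens: int = 77, dim: int = 768):
        super().__init__()
        self.ctx_alpha = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())  # stand-in gate
        self.alpha = nn.Parameter(torch.zeros([text_tokens]) + 0.01)     # per-token scalar, starts near zero

    def forward(self, text_embeds: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        ctx_alpha = self.ctx_alpha(text_embeds)        # (B, 77, 1)
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # (1, 77, 1)
        return ctx_alpha * x * alpha + text_embeds     # fused signal fades in as alpha grows


out = GatedResidualFusion()(torch.randn(2, 77, 768), torch.randn(2, 77, 768))
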
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index c81ef442..f485366c 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -841,6 +841,14 @@ class StableDiffusion:
             else:
                 timestep = timestep.repeat(latents.shape[0], 0)
 
+        # handle t2i adapters
+        if 'down_intrablock_additional_residuals' in kwargs:
+            # go through each item and concat if doing cfg and it doesn't have the same shape
+            for idx, item in enumerate(kwargs['down_intrablock_additional_residuals']):
+                if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
+                    kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([item] * 2, dim=0)
+
         def scale_model_input(model_input, timestep_tensor):
             if is_input_scaled:
                 return model_input
@@ -1599,6 +1607,8 @@ class StableDiffusion:
         training_modules = []
         if device_state_preset in ['cache_latents']:
             active_modules = ['vae']
+        if device_state_preset in ['cache_clip']:
+            active_modules = ['clip']
         if device_state_preset in ['generate']:
             active_modules = ['vae', 'unet', 'text_encoder', 'adapter', 'refiner_unet']
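
Finally, the stable_diffusion_model.py hunk duplicates any T2I-adapter residual whose batch size does not match the concatenated (uncond + cond) text embeddings during classifier-free guidance. A standalone sketch of that fix-up:

# illustrative sketch only -- not part of the patch
from typing import List

import torch


def expand_residuals_for_cfg(residuals: List[torch.Tensor], embeds_batch: int, do_cfg: bool) -> List[torch.Tensor]:
    out = []
    for item in residuals:
        if do_cfg and item.shape[0] != embeds_batch:
            item = torch.cat([item] * 2, dim=0)  # repeat for the unconditional half
        out.append(item)
    return out


residuals = [torch.randn(1, 320, 64, 64), torch.randn(1, 640, 32, 32)]
residuals = expand_residuals_for_cfg(residuals, embeds_batch=2, do_cfg=True)
# each residual now has batch size 2, matching uncond + cond embeddings
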