From 92b9c71d4418394d3ba224a04ba8034273acaa42 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 28 Jan 2024 08:20:03 -0700 Subject: [PATCH] Many bug fixes. Ip adapter bug fixes. Added noise to unconditional, it works better. added an ilora adapter for 1 shotting LoRAs --- extensions_built_in/sd_trainer/SDTrainer.py | 24 ++-- jobs/process/BaseSDTrainProcess.py | 37 +++-- toolkit/config_modules.py | 8 ++ toolkit/custom_adapter.py | 45 ++++-- toolkit/dataloader_mixins.py | 15 ++ toolkit/ip_adapter.py | 150 +++++++++++++++++--- toolkit/lora_special.py | 9 +- toolkit/metadata.py | 2 + toolkit/models/ilora.py | 104 ++++++++++++++ toolkit/saving.py | 14 +- 10 files changed, 352 insertions(+), 56 deletions(-) create mode 100644 toolkit/models/ilora.py diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 4dc8669b..b9e5ed69 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -286,9 +286,10 @@ class SDTrainer(BaseSDTrainProcess): ) prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier if torch.isnan(prior_loss).any(): - raise ValueError("Prior loss is nan") - - prior_loss = prior_loss.mean([1, 2, 3]) + print("Prior loss is nan") + prior_loss = None + else: + prior_loss = prior_loss.mean([1, 2, 3]) # loss = loss + prior_loss loss = loss.mean([1, 2, 3]) if prior_loss is not None: @@ -992,6 +993,8 @@ class SDTrainer(BaseSDTrainProcess): if self.adapter and isinstance(self.adapter, IPAdapter): with self.timer('encode_adapter_embeds'): + # number of images to do if doing a quad image + quad_count = random.randint(1, 4) image_size = self.adapter.input_size if is_reg: # we will zero it out in the img embedder @@ -1004,7 +1007,8 @@ class SDTrainer(BaseSDTrainProcess): clip_images, drop=True, is_training=True, - has_been_preprocessed=True + has_been_preprocessed=True, + quad_count=quad_count ) if self.train_config.do_cfg: unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( @@ -1014,13 +1018,15 @@ class SDTrainer(BaseSDTrainProcess): ).detach(), is_training=True, drop=True, - has_been_preprocessed=True + has_been_preprocessed=True, + quad_count=quad_count ) elif has_clip_image: conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( clip_images.detach().to(self.device_torch, dtype=dtype), is_training=True, - has_been_preprocessed=True + has_been_preprocessed=True, + quad_count=quad_count ) if self.train_config.do_cfg: unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( @@ -1030,7 +1036,8 @@ class SDTrainer(BaseSDTrainProcess): ).detach(), is_training=True, drop=True, - has_been_preprocessed=True + has_been_preprocessed=True, + quad_count=quad_count ) else: raise ValueError("Adapter images now must be loaded with dataloader or be a reg image") @@ -1152,7 +1159,8 @@ class SDTrainer(BaseSDTrainProcess): ) # check if nan if torch.isnan(loss): - raise ValueError("loss is nan") + print("loss is nan") + loss = torch.zeros_like(loss).requires_grad_(True) with self.timer('backward'): # todo we have multiplier seperated. 
works for now as res are not in same batch, but need to change diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index ab27bd2e..70ae28b2 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -2,6 +2,7 @@ import copy import glob import inspect import json +import random import shutil from collections import OrderedDict import os @@ -423,7 +424,7 @@ class BaseSDTrainProcess(BaseTrainProcess): adapter_name += '_t2i' elif self.adapter_config.type == 'clip': adapter_name += '_clip' - elif self.adapter_config.type == 'ip': + elif self.adapter_config.type.startswith('ip'): adapter_name += '_ip' else: adapter_name += '_adapter' @@ -444,7 +445,8 @@ class BaseSDTrainProcess(BaseTrainProcess): state_dict, output_file=file_path, meta=save_meta, - dtype=get_torch_dtype(self.save_config.dtype) + dtype=get_torch_dtype(self.save_config.dtype), + direct_save=self.adapter_config.train_only_image_encoder ) else: if self.save_config.save_format == "diffusers": @@ -1010,7 +1012,8 @@ class BaseSDTrainProcess(BaseTrainProcess): loaded_state_dict = load_ip_adapter_model( latest_save_path, self.device, - dtype=dtype + dtype=dtype, + direct_load=self.adapter_config.train_only_image_encoder ) self.adapter.load_state_dict(loaded_state_dict) else: @@ -1146,14 +1149,11 @@ class BaseSDTrainProcess(BaseTrainProcess): self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch) self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch) - # load the adapters before the dataset as they may use the clip encoders - if self.adapter_config is not None: - self.setup_adapter() flush() if not self.is_fine_tuning: if self.network_config is not None: # TODO should we completely switch to LycorisSpecialNetwork? 
- network_kwargs = {} + network_kwargs = self.network_config.network_kwargs is_lycoris = False is_lorm = self.network_config.type.lower() == 'lorm' # default to LoCON if there are any conv layers or if it is named @@ -1279,12 +1279,16 @@ class BaseSDTrainProcess(BaseTrainProcess): flush() if self.adapter_config is not None: - # self.setup_adapter() - # set trainable params - params.append({ - 'params': self.adapter.parameters(), - 'lr': self.train_config.adapter_lr - }) + self.setup_adapter() + if self.adapter_config.train: + # set trainable params + params.append({ + 'params': self.adapter.parameters(), + 'lr': self.train_config.adapter_lr + }) + + if self.train_config.gradient_checkpointing: + self.adapter.enable_gradient_checkpointing() flush() params = self.load_additional_training_modules(params) @@ -1306,7 +1310,7 @@ class BaseSDTrainProcess(BaseTrainProcess): refiner_lr=self.train_config.refiner_lr, ) # we may be using it for prompt injections - if self.adapter_config is not None: + if self.adapter_config is not None and self.adapter is None: self.setup_adapter() flush() ### HOOK ### @@ -1379,7 +1383,7 @@ class BaseSDTrainProcess(BaseTrainProcess): # sample first if self.train_config.skip_first_sample: self.print("Skipping first sample due to config setting") - elif self.step_num <= 1: + elif self.step_num <= 1 or self.train_config.force_first_sample: self.print("Generating baseline samples before training") self.sample(self.step_num) @@ -1422,6 +1426,9 @@ class BaseSDTrainProcess(BaseTrainProcess): start_step_num = self.step_num did_first_flush = False for step in range(start_step_num, self.train_config.steps): + if self.train_config.do_random_cfg: + self.train_config.do_cfg = True + self.train_config.cfg_scale = value_map(random.random(), 0, 1, 1.0, self.train_config.max_cfg_scale) self.step_num = step # default to true so various things can turn it off self.is_grad_accumulation_step = True diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 284a3227..de9f6eba 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -113,6 +113,7 @@ class NetworkConfig: self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha) self.conv_alpha: float = kwargs.get('conv_alpha', self.conv) self.dropout: Union[float, None] = kwargs.get('dropout', None) + self.network_kwargs: dict = kwargs.get('network_kwargs', {}) self.lorm_config: Union[LoRMConfig, None] = None lorm = kwargs.get('lorm', None) @@ -153,10 +154,14 @@ class AdapterConfig: self.num_tokens: int = num_tokens self.train_image_encoder: bool = kwargs.get('train_image_encoder', False) + self.train_only_image_encoder: bool = kwargs.get('train_only_image_encoder', False) + if self.train_only_image_encoder: + self.train_image_encoder = True self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid, safe self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512) self.safe_channels: int = kwargs.get('safe_channels', 2048) self.safe_tokens: int = kwargs.get('safe_tokens', 8) + self.quad_image: bool = kwargs.get('quad_image', False) # clip vision self.trigger = kwargs.get('trigger', 'tri993r') @@ -211,6 +216,7 @@ class TrainConfig: self.learnable_snr_gos = kwargs.get('learnable_snr_gos', False) self.noise_offset = kwargs.get('noise_offset', 0.0) self.skip_first_sample = kwargs.get('skip_first_sample', False) + self.force_first_sample = kwargs.get('force_first_sample', False) self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True) 
self.weight_jitter = kwargs.get('weight_jitter', 0.0) self.merge_network_on_save = kwargs.get('merge_network_on_save', False) @@ -275,7 +281,9 @@ class TrainConfig: self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False) self.do_cfg = kwargs.get('do_cfg', False) + self.do_random_cfg = kwargs.get('do_random_cfg', False) self.cfg_scale = kwargs.get('cfg_scale', 1.0) + self.max_cfg_scale = kwargs.get('max_cfg_scale', self.cfg_scale) class ModelConfig: diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py index 0aef07b4..810f7d67 100644 --- a/toolkit/custom_adapter.py +++ b/toolkit/custom_adapter.py @@ -7,6 +7,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from toolkit.models.clip_fusion import CLIPFusionModule from toolkit.models.clip_pre_processor import CLIPImagePreProcessor +from toolkit.models.ilora import InstantLoRAModule from toolkit.paths import REPOS_ROOT from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder from toolkit.saving import load_ip_adapter_model @@ -74,6 +75,7 @@ class CustomAdapter(torch.nn.Module): self.clip_image_processor = self.image_processor self.clip_fusion_module: CLIPFusionModule = None + self.ilora_module: InstantLoRAModule = None self.setup_adapter() @@ -106,6 +108,15 @@ class CustomAdapter(torch.nn.Module): vision_hidden_size=self.vision_encoder.config.hidden_size, vision_tokens=vision_tokens ) + elif self.adapter_type == 'ilora': + vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2) + if self.config.image_encoder_arch == 'clip': + vision_tokens = vision_tokens + 1 + self.ilora_module = InstantLoRAModule( + vision_tokens=vision_tokens, + vision_hidden_size=self.vision_encoder.config.hidden_size, + sd=self.sd_ref() + ) else: raise ValueError(f"unknown adapter type: {self.adapter_type}") @@ -283,6 +294,9 @@ class CustomAdapter(torch.nn.Module): if 'fuse_module' in state_dict: self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict) + if 'ilora' in state_dict: + self.ilora_module.load_state_dict(state_dict['ilora'], strict=strict) + pass def state_dict(self) -> OrderedDict: @@ -301,6 +315,11 @@ class CustomAdapter(torch.nn.Module): state_dict["vision_encoder"] = self.vision_encoder.state_dict() state_dict["clip_fusion"] = self.clip_fusion_module.state_dict() return state_dict + elif self.adapter_type == 'ilora': + if self.config.train_image_encoder: + state_dict["vision_encoder"] = self.vision_encoder.state_dict() + state_dict["ilora"] = self.ilora_module.state_dict() + return state_dict else: raise NotImplementedError @@ -309,7 +328,7 @@ class CustomAdapter(torch.nn.Module): prompt: Union[List[str], str], is_unconditional: bool = False, ): - if self.adapter_type == 'clip_fusion': + if self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora': return prompt elif self.adapter_type == 'photo_maker': if is_unconditional: @@ -408,7 +427,7 @@ class CustomAdapter(torch.nn.Module): has_been_preprocessed=False, is_unconditional=False ) -> PromptEmbeds: - if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion': + if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora': if is_unconditional: # we dont condition the negative embeds for photo maker return prompt_embeds.clone() @@ -459,7 +478,7 @@ class CustomAdapter(torch.nn.Module): self.token_mask ) return prompt_embeds - elif self.adapter_type == 'clip_fusion': + elif 
self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora': with torch.set_grad_enabled(is_training): if is_training and self.config.train_image_encoder: self.vision_encoder.train() @@ -480,11 +499,17 @@ class CustomAdapter(torch.nn.Module): if not is_training or not self.config.train_image_encoder: img_embeds = img_embeds.detach() - prompt_embeds.text_embeds = self.clip_fusion_module( - prompt_embeds.text_embeds, - img_embeds - ) - return prompt_embeds + if self.adapter_type == 'ilora': + self.ilora_module.img_embeds = img_embeds + + return prompt_embeds + else: + + prompt_embeds.text_embeds = self.clip_fusion_module( + prompt_embeds.text_embeds, + img_embeds + ) + return prompt_embeds else: @@ -499,5 +524,9 @@ class CustomAdapter(torch.nn.Module): yield from self.clip_fusion_module.parameters(recurse) if self.config.train_image_encoder: yield from self.vision_encoder.parameters(recurse) + elif self.config.type == 'ilora': + yield from self.ilora_module.parameters(recurse) + if self.config.train_image_encoder: + yield from self.vision_encoder.parameters(recurse) else: raise NotImplementedError diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 3a82511e..ffd13156 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -655,6 +655,21 @@ class ClipImageFileItemDTOMixin: else: self.clip_image_tensor = transforms.ToTensor()(img) + # random crop + # if self.dataset_config.clip_image_random_crop: + # # crop up to 20% on all sides. Keep is square + # crop_percent = random.randint(0, 20) / 100 + # crop_width = int(self.clip_image_tensor.shape[2] * crop_percent) + # crop_height = int(self.clip_image_tensor.shape[1] * crop_percent) + # crop_left = random.randint(0, crop_width) + # crop_top = random.randint(0, crop_height) + # crop_right = self.clip_image_tensor.shape[2] - crop_width - crop_left + # crop_bottom = self.clip_image_tensor.shape[1] - crop_height - crop_top + # if len(self.clip_image_tensor.shape) == 3: + # self.clip_image_tensor = self.clip_image_tensor[:, crop_top:-crop_bottom, crop_left:-crop_right] + # elif len(self.clip_image_tensor.shape) == 4: + # self.clip_image_tensor = self.clip_image_tensor[:, :, crop_top:-crop_bottom, crop_left:-crop_right] + if self.clip_image_processor is not None: # run it tensors_0_1 = self.clip_image_tensor.to(dtype=torch.float16) diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py index b7bb7034..fce13e90 100644 --- a/toolkit/ip_adapter.py +++ b/toolkit/ip_adapter.py @@ -1,3 +1,5 @@ +import random + import torch import sys @@ -39,6 +41,9 @@ from transformers import ViTHybridImageProcessor, ViTHybridForImageClassificatio from transformers import ViTFeatureExtractor, ViTForImageClassification +# gradient checkpointing +from torch.utils.checkpoint import checkpoint + import torch.nn.functional as F @@ -166,6 +171,8 @@ class IPAdapter(torch.nn.Module): self.device = self.sd_ref().unet.device self.preprocessor: Optional[CLIPImagePreProcessor] = None self.input_size = 224 + self.clip_noise_zero = True + self.unconditional: torch.Tensor = None if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+': try: self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path) @@ -236,6 +243,16 @@ class IPAdapter(torch.nn.Module): self.input_size = self.image_encoder.config.image_size + if self.config.quad_image: # 4x4 image + # self.clip_image_processor.config + # We do a 3x downscale of the image, so we need to adjust the input size + 
preprocessor_input_size = self.image_encoder.config.image_size * 2 + + # update the preprocessor so images come in at the right size + self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size + self.clip_image_processor.crop_size['height'] = preprocessor_input_size + self.clip_image_processor.crop_size['width'] = preprocessor_input_size + if self.config.image_encoder_arch == 'clip+': # self.clip_image_processor.config # We do a 3x downscale of the image, so we need to adjust the input size @@ -349,6 +366,15 @@ class IPAdapter(torch.nn.Module): self.image_encoder.train() self.image_encoder.requires_grad_(True) + # premake a unconditional + zerod = torch.zeros(1, 3, self.input_size, self.input_size, device=self.device, dtype=torch.float16) + self.unconditional = self.clip_image_processor( + images=zerod, + return_tensors="pt", + do_resize=True, + do_rescale=False, + ).pixel_values + def to(self, *args, **kwargs): super().to(*args, **kwargs) self.image_encoder.to(*args, **kwargs) @@ -358,20 +384,23 @@ class IPAdapter(torch.nn.Module): self.preprocessor.to(*args, **kwargs) return self - def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]): - self.image_proj_model.load_state_dict(state_dict["image_proj"]) - ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) - ip_layers.load_state_dict(state_dict["ip_adapter"]) - if self.config.train_image_encoder and 'image_encoder' in state_dict: - self.image_encoder.load_state_dict(state_dict["image_encoder"]) - if self.preprocessor is not None and 'preprocessor' in state_dict: - self.preprocessor.load_state_dict(state_dict["preprocessor"]) + # def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]): + # self.image_proj_model.load_state_dict(state_dict["image_proj"]) + # ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + # ip_layers.load_state_dict(state_dict["ip_adapter"]) + # if self.config.train_image_encoder and 'image_encoder' in state_dict: + # self.image_encoder.load_state_dict(state_dict["image_encoder"]) + # if self.preprocessor is not None and 'preprocessor' in state_dict: + # self.preprocessor.load_state_dict(state_dict["preprocessor"]) # def load_state_dict(self, state_dict: Union[OrderedDict, dict]): # self.load_ip_adapter(state_dict) def state_dict(self) -> OrderedDict: state_dict = OrderedDict() + if self.config.train_only_image_encoder: + return self.image_encoder.state_dict() + state_dict["image_proj"] = self.image_proj_model.state_dict() state_dict["ip_adapter"] = self.adapter_modules.state_dict() if self.config.train_image_encoder: @@ -402,13 +431,28 @@ class IPAdapter(torch.nn.Module): # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] # return clip_image_embeds + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.image_encoder.to(*args, **kwargs) + self.image_proj_model.to(*args, **kwargs) + self.adapter_modules.to(*args, **kwargs) + if self.preprocessor is not None: + self.preprocessor.to(*args, **kwargs) + return self def get_clip_image_embeds_from_tensors( self, tensors_0_1: torch.Tensor, drop=False, is_training=False, - has_been_preprocessed=False + has_been_preprocessed=False, + quad_count=4, ) -> torch.Tensor: + if self.sd_ref().unet.device != self.device: + self.to(self.sd_ref().unet.device) + if self.sd_ref().unet.device != self.image_encoder.device: + self.to(self.sd_ref().unet.device) + if not self.config.train: + is_training = False with torch.no_grad(): # on training the clip image 
is created in the dataloader if not has_been_preprocessed: @@ -417,11 +461,19 @@ class IPAdapter(torch.nn.Module): tensors_0_1 = tensors_0_1.unsqueeze(0) # training tensors are 0 - 1 tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16) + # if images are out of this range throw error if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3: raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format( tensors_0_1.min(), tensors_0_1.max() )) + # unconditional + if drop: + if self.clip_noise_zero: + tensors_0_1 = torch.rand_like(tensors_0_1).detach() + else: + tensors_0_1 = torch.zeros_like(tensors_0_1).detach() + # tensors_0_1 = tensors_0_1 * 0 clip_image = self.clip_image_processor( images=tensors_0_1, return_tensors="pt", @@ -429,10 +481,42 @@ class IPAdapter(torch.nn.Module): do_rescale=False, ).pixel_values else: - clip_image = tensors_0_1 + if drop: + # scale the noise down + if self.clip_noise_zero: + tensors_0_1 = torch.rand_like(tensors_0_1).detach() + else: + tensors_0_1 = torch.zeros_like(tensors_0_1).detach() + # tensors_0_1 = tensors_0_1 * 0 + mean = torch.tensor(self.clip_image_processor.image_mean).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype) + ).detach() + std = torch.tensor(self.clip_image_processor.image_std).to( + self.device, dtype=get_torch_dtype(self.sd_ref().dtype) + ).detach() + tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0 + clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1]) + + else: + clip_image = tensors_0_1 clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach() - if drop: - clip_image = clip_image * 0 + + if self.config.quad_image: + # split the 4x4 grid and stack on batch + ci1, ci2 = clip_image.chunk(2, dim=2) + ci1, ci3 = ci1.chunk(2, dim=3) + ci2, ci4 = ci2.chunk(2, dim=3) + to_cat = [] + for i, ci in enumerate([ci1, ci2, ci3, ci4]): + if i < quad_count: + to_cat.append(ci) + else: + break + + clip_image = torch.cat(to_cat, dim=0).detach() + + # if drop: + # clip_image = clip_image * 0 with torch.set_grad_enabled(is_training): if is_training: self.image_encoder.train() @@ -457,6 +541,20 @@ class IPAdapter(torch.nn.Module): clip_image_embeds = clip_output.hidden_states[-2] else: clip_image_embeds = clip_output.image_embeds + + if self.config.quad_image: + # get the outputs of the quat + chunks = clip_image_embeds.chunk(quad_count, dim=0) + chunk_sum = torch.zeros_like(chunks[0]) + for chunk in chunks: + chunk_sum = chunk_sum + chunk + # get the mean of them + + clip_image_embeds = chunk_sum / quad_count + + if not is_training: + clip_image_embeds = clip_image_embeds.detach() + return clip_image_embeds # use drop for prompt dropout, or negatives @@ -467,6 +565,9 @@ class IPAdapter(torch.nn.Module): return embeddings def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + if self.config.train_only_image_encoder: + yield from self.image_encoder.parameters(recurse) + return for attn_processor in self.adapter_modules: yield from attn_processor.parameters(recurse) yield from self.image_proj_model.parameters(recurse) @@ -561,17 +662,26 @@ class IPAdapter(torch.nn.Module): def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): strict = False - try: - self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict) - self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict) - except Exception as e: - print(e) - print("could not load ip adapter weights, 
trying to merge in weights") - self.merge_in_weights(state_dict) + if 'ip_adapter' in state_dict: + try: + self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict) + self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict) + except Exception as e: + print(e) + print("could not load ip adapter weights, trying to merge in weights") + self.merge_in_weights(state_dict) if self.config.train_image_encoder and 'image_encoder' in state_dict: self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict) if self.preprocessor is not None and 'preprocessor' in state_dict: self.preprocessor.load_state_dict(state_dict["preprocessor"], strict=strict) + if self.config.train_only_image_encoder and 'ip_adapter' not in state_dict: + # we are loading pure clip weights. + self.image_encoder.load_state_dict(state_dict, strict=strict) + + def enable_gradient_checkpointing(self): - self.image_encoder.gradient_checkpointing = True + if hasattr(self.image_encoder, "enable_gradient_checkpointing"): + self.image_encoder.enable_gradient_checkpointing() + elif hasattr(self.image_encoder, 'gradient_checkpointing'): + self.image_encoder.gradient_checkpointing = True diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 2ad5f3f0..2fd18ab5 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -114,9 +114,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): # UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] # UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "ResnetBlock2D"] - UNET_TARGET_REPLACE_MODULE = ["''UNet2DConditionModel''"] + UNET_TARGET_REPLACE_MODULE = ["UNet2DConditionModel"] # UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["'UNet2DConditionModel'"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["UNet2DConditionModel"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] LORA_PREFIX_UNET = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" @@ -155,6 +155,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): is_lorm: bool = False, ignore_if_contains = None, parameter_threshold: float = 0.0, + attn_only: bool = False, target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE, target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3, **kwargs @@ -243,6 +244,10 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): # for child_name, child_module in module.named_modules(): is_linear = module_name == 'LoRACompatibleLinear' is_conv2d = module_name == 'LoRACompatibleConv' + # check if attn in name + is_attention = "attentions" in name + if not is_attention and attn_only: + continue if is_linear and self.lora_dim is None: continue diff --git a/toolkit/metadata.py b/toolkit/metadata.py index bf969f09..4a5c36ad 100644 --- a/toolkit/metadata.py +++ b/toolkit/metadata.py @@ -23,6 +23,8 @@ def get_meta_for_safetensors(meta: OrderedDict, name=None, add_software_info=Tru # if not float, int, bool, or str, convert to json string if not isinstance(value, str): save_meta[key] = json.dumps(value) + # add the pt format + save_meta["format"] = "pt" return save_meta diff --git a/toolkit/models/ilora.py b/toolkit/models/ilora.py new file mode 100644 index 00000000..9352d164 --- /dev/null +++ b/toolkit/models/ilora.py @@ -0,0 +1,104 @@ +import weakref + +import torch +import torch.nn as nn +from typing import TYPE_CHECKING +from toolkit.models.clip_fusion import ZipperBlock + +if TYPE_CHECKING: + from 
toolkit.lora_special import LoRAModule + from toolkit.stable_diffusion_model import StableDiffusion + + +class InstantLoRAMidModule(torch.nn.Module): + def __init__( + self, + dim: int, + vision_tokens: int, + vision_hidden_size: int, + lora_module: 'LoRAModule', + instant_lora_module: 'InstantLoRAModule' + ): + super(InstantLoRAMidModule, self).__init__() + self.dim = dim + self.vision_tokens = vision_tokens + self.vision_hidden_size = vision_hidden_size + self.lora_module_ref = weakref.ref(lora_module) + self.instant_lora_module_ref = weakref.ref(instant_lora_module) + + self.zip = ZipperBlock( + in_size=self.vision_hidden_size, + in_tokens=self.vision_tokens, + out_size=self.dim, + out_tokens=1, + hidden_size=self.dim, + hidden_tokens=self.vision_tokens + ) + + def forward(self, x, *args, **kwargs): + # get the vector + img_embeds = self.instant_lora_module_ref().img_embeds + # project it + scaler = self.zip(img_embeds) # (batch_size, 1, dim) + + # remove the channel dim + scaler = scaler.squeeze(1) + + # double up if batch is 2x the size on x (cfg) + if x.shape[0] // 2 == scaler.shape[0]: + scaler = torch.cat([scaler, scaler], dim=0) + + # multiply it by the scaler + try: + # reshape if needed + if len(x.shape) == 3: + scaler = scaler.unsqueeze(1) + x = x * scaler + except Exception as e: + print(e) + print(x.shape) + print(scaler.shape) + raise e + # apply tanh to limit values to -1 to 1 + scaler = torch.tanh(scaler) + return x * scaler + + +class InstantLoRAModule(torch.nn.Module): + def __init__( + self, + vision_hidden_size: int, + vision_tokens: int, + sd: 'StableDiffusion' + ): + super(InstantLoRAModule, self).__init__() + self.linear = torch.nn.Linear(2, 1) + self.sd_ref = weakref.ref(sd) + self.dim = sd.network.lora_dim + self.vision_hidden_size = vision_hidden_size + self.vision_tokens = vision_tokens + + # stores the projection vector. Grabbed by modules + self.img_embeds: torch.Tensor = None + + # disable merging in. 
It is slower on inference + self.sd_ref().network.can_merge_in = False + + self.ilora_modules = torch.nn.ModuleList() + + lora_modules = self.sd_ref().network.get_all_modules() + + for lora_module in lora_modules: + # add a new mid module that will take the original forward and add a vector to it + # this will be used to add the vector to the original forward + mid_module = InstantLoRAMidModule(self.dim, self.vision_tokens, self.vision_hidden_size, lora_module, self) + + self.ilora_modules.append(mid_module) + # replace the LoRA lora_mid + lora_module.lora_mid = mid_module.forward + + # add a new mid module that will take the original forward and add a vector to it + # this will be used to add the vector to the original forward + + def forward(self, x): + return self.linear(x) diff --git a/toolkit/saving.py b/toolkit/saving.py index 597f4620..e2a90abb 100644 --- a/toolkit/saving.py +++ b/toolkit/saving.py @@ -215,12 +215,17 @@ def save_ip_adapter_from_diffusers( output_file: str, meta: 'OrderedDict', dtype=get_torch_dtype('fp16'), + direct_save: bool = False ): # todo: test compatibility with non diffusers + converted_state_dict = OrderedDict() for module_name, state_dict in combined_state_dict.items(): - for key, value in state_dict.items(): - converted_state_dict[f"{module_name}.{key}"] = value.detach().to('cpu', dtype=dtype) + if direct_save: + converted_state_dict[module_name] = state_dict.detach().to('cpu', dtype=dtype) + else: + for key, value in state_dict.items(): + converted_state_dict[f"{module_name}.{key}"] = value.detach().to('cpu', dtype=dtype) # make sure parent folder exists os.makedirs(os.path.dirname(output_file), exist_ok=True) @@ -230,12 +235,15 @@ def save_ip_adapter_from_diffusers( def load_ip_adapter_model( path_to_file, device: Union[str] = 'cpu', - dtype: torch.dtype = torch.float32 + dtype: torch.dtype = torch.float32, + direct_load: bool = False ): # check if it is safetensors or checkpoint if path_to_file.endswith('.safetensors'): raw_state_dict = load_file(path_to_file, device) combined_state_dict = OrderedDict() + if direct_load: + return raw_state_dict for combo_key, value in raw_state_dict.items(): key_split = combo_key.split('.') module_name = key_split.pop(0)
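The sketches that follow illustrate the main mechanisms this patch introduces; helper names are illustrative stand-ins, not toolkit APIs, unless noted. First, the softer NaN handling in SDTrainer.py: instead of raising, a NaN loss is now logged and replaced with a zero tensor that still requires grad, so backward() becomes a no-op for that step. A minimal sketch, with guard_nan_loss as a hypothetical name:

import torch

def guard_nan_loss(loss: torch.Tensor) -> torch.Tensor:
    # mirror the behaviour above: log and neutralise a NaN loss instead of aborting
    if torch.isnan(loss):
        print("loss is nan")
        return torch.zeros_like(loss).requires_grad_(True)
    return loss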
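The train loop now supports do_random_cfg: when set, each step turns CFG on and draws cfg_scale uniformly between 1.0 and max_cfg_scale via value_map. A sketch of that draw, with value_map written out locally as a stand-in for the helper the loop calls:

import random

def value_map(x: float, in_min: float, in_max: float, out_min: float, out_max: float) -> float:
    # linearly remap x from [in_min, in_max] to [out_min, out_max]
    return out_min + (x - in_min) * (out_max - out_min) / (in_max - in_min)

def sample_cfg_scale(max_cfg_scale: float) -> float:
    # one draw per training step when do_random_cfg is enabled
    return value_map(random.random(), 0.0, 1.0, 1.0, max_cfg_scale)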
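The new options read in config_modules.py map one-to-one onto constructor kwargs. The dicts below show plausible settings; the values and the attn_only forwarding through network_kwargs are assumptions for illustration, not defaults:

network_config_kwargs = dict(
    # forwarded verbatim to the network constructor via NetworkConfig.network_kwargs
    network_kwargs={"attn_only": True},
)

adapter_config_kwargs = dict(
    train_only_image_encoder=True,  # also forces train_image_encoder=True
    quad_image=True,                # feed a 2x2 image grid, average the crop embeddings
)

train_config_kwargs = dict(
    force_first_sample=True,        # generate baseline samples even when resuming
    do_random_cfg=True,
    max_cfg_scale=3.0,              # per-step cfg_scale is drawn from [1.0, max_cfg_scale]
)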
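For the quad_image path in ip_adapter.py, the preprocessor size is doubled so a 2x2 grid arrives at twice the encoder resolution; the grid is then split into four crops stacked on the batch dimension, encoded once, and the per-crop embeddings are averaged, with quad_count limiting how many crops are used. A condensed sketch, with the crop order following the chunking above:

import torch

def split_quad(clip_image: torch.Tensor, quad_count: int = 4) -> torch.Tensor:
    # clip_image: (B, C, 2*H, 2*W) -> (quad_count*B, C, H, W)
    ci1, ci2 = clip_image.chunk(2, dim=2)   # top / bottom halves
    ci1, ci3 = ci1.chunk(2, dim=3)          # top-left / top-right
    ci2, ci4 = ci2.chunk(2, dim=3)          # bottom-left / bottom-right
    return torch.cat([ci1, ci2, ci3, ci4][:quad_count], dim=0).detach()

def merge_quad_embeds(clip_image_embeds: torch.Tensor, quad_count: int = 4) -> torch.Tensor:
    # undo the batch stacking by averaging the embeddings of the crops
    chunks = clip_image_embeds.chunk(quad_count, dim=0)
    return torch.stack(chunks, dim=0).mean(dim=0)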
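Dropped (unconditional) image embeddings are now produced from uniform noise rather than a black image, which the commit message reports works better; when the batch was already preprocessed by the dataloader, the noise is normalised by hand with the processor's mean/std. A sketch, assuming a CLIPImageProcessor-style object with image_mean and image_std:

import torch

def make_unconditional_clip_input(
    tensors_0_1: torch.Tensor,          # images in [0, 1]
    processor,                          # CLIPImageProcessor-like, exposing image_mean / image_std
    clip_noise_zero: bool = True,
) -> torch.Tensor:
    if clip_noise_zero:
        tensors_0_1 = torch.rand_like(tensors_0_1).detach()   # uniform noise in [0, 1]
    else:
        tensors_0_1 = torch.zeros_like(tensors_0_1).detach()  # old behaviour: black image
    mean = torch.tensor(processor.image_mean, device=tensors_0_1.device,
                        dtype=tensors_0_1.dtype).view(1, 3, 1, 1)
    std = torch.tensor(processor.image_std, device=tensors_0_1.device,
                       dtype=tensors_0_1.dtype).view(1, 3, 1, 1)
    # quantise to 8-bit levels, then normalise the way the processor would
    tensors_0_1 = torch.clip(tensors_0_1 * 255.0, 0, 255).round() / 255.0
    return (tensors_0_1 - mean) / std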
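The new InstantLoRAModule in toolkit/models/ilora.py caches the CLIP image embeddings on each prompt encode and hooks every LoRA module's lora_mid, so the rank-dimension activation is scaled by an image-conditioned vector. A condensed sketch of that scaling; ZipperBlock is abstracted behind zip_block, and applying tanh before a single multiply is a simplification of the forward as written above:

import torch

def scale_lora_mid(x: torch.Tensor, img_embeds: torch.Tensor, zip_block) -> torch.Tensor:
    scaler = zip_block(img_embeds)           # (batch, 1, lora_dim)
    scaler = scaler.squeeze(1)               # (batch, lora_dim)
    if x.shape[0] == scaler.shape[0] * 2:    # CFG doubles the batch on x
        scaler = torch.cat([scaler, scaler], dim=0)
    if x.dim() == 3:                         # sequence-shaped activations need a token dim
        scaler = scaler.unsqueeze(1)
    return x * torch.tanh(scaler)            # squash to [-1, 1] and scale the mid activation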
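Finally, when train_only_image_encoder is set the adapter's state_dict is just the CLIP vision encoder, so saving.py gains direct_save / direct_load paths that write and read a flat state dict instead of the nested image_proj / ip_adapter layout. A minimal sketch using safetensors, with metadata handling omitted:

import os
from collections import OrderedDict

import torch
from safetensors.torch import save_file, load_file

def save_direct(state_dict: dict, output_file: str, dtype: torch.dtype = torch.float16) -> None:
    # flatten to CPU tensors at the save dtype and write the file directly
    flat = OrderedDict((k, v.detach().to('cpu', dtype=dtype)) for k, v in state_dict.items())
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    save_file(flat, output_file)

def load_direct(path_to_file: str, device: str = 'cpu') -> dict:
    # the flat encoder state dict can be loaded back as-is
    return load_file(path_to_file, device)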