diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 066969a1..1647d513 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -225,52 +225,51 @@ class SDTrainer(BaseSDTrainProcess): noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True) noise_pred = noise_pred * (noise_norm / noise_pred_norm) - if self.train_config.correct_pred_norm and not is_reg: - with torch.no_grad(): - # this only works if doing a prior pred - if prior_pred is not None: - prior_mean = prior_pred.mean([2,3], keepdim=True) - prior_std = prior_pred.std([2,3], keepdim=True) - noise_mean = noise_pred.mean([2,3], keepdim=True) - noise_std = noise_pred.std([2,3], keepdim=True) + target = None + if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask): + if self.train_config.correct_pred_norm and not is_reg: + with torch.no_grad(): + # this only works if doing a prior pred + if prior_pred is not None: + prior_mean = prior_pred.mean([2,3], keepdim=True) + prior_std = prior_pred.std([2,3], keepdim=True) + noise_mean = noise_pred.mean([2,3], keepdim=True) + noise_std = noise_pred.std([2,3], keepdim=True) - mean_adjust = prior_mean - noise_mean - std_adjust = prior_std - noise_std + mean_adjust = prior_mean - noise_mean + std_adjust = prior_std - noise_std - mean_adjust = mean_adjust * self.train_config.correct_pred_norm_multiplier - std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier + mean_adjust = mean_adjust * self.train_config.correct_pred_norm_multiplier + std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier - target_mean = noise_mean + mean_adjust - target_std = noise_std + std_adjust + target_mean = noise_mean + mean_adjust + target_std = noise_std + std_adjust - eps = 1e-5 + eps = 1e-5 + # match the noise to the prior + noise = (noise - noise_mean) / (noise_std + eps) + noise = noise * (target_std + eps) + target_mean + noise = noise.detach() - # adjust the noise target to match the current knowledge of the model - # noise_mean, noise_std = get_mean_std(noise) - # match the noise to the prior - noise = (noise - noise_mean) / (noise_std + eps) - noise = noise * (target_std + eps) + target_mean - noise = noise.detach() + if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask: + assert not self.train_config.train_turbo + # we need to make the noise prediction be a masked blending of noise and prior_pred + stretched_mask_multiplier = value_map( + mask_multiplier, + batch.file_items[0].dataset_config.mask_min_value, + 1.0, + 0.0, + 1.0 + ) - if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask: - assert not self.train_config.train_turbo - # we need to make the noise prediction be a masked blending of noise and prior_pred - stretched_mask_multiplier = value_map( - mask_multiplier, - batch.file_items[0].dataset_config.mask_min_value, - 1.0, - 0.0, - 1.0 - ) + prior_mask_multiplier = 1.0 - stretched_mask_multiplier - prior_mask_multiplier = 1.0 - stretched_mask_multiplier - - # target_mask_multiplier = mask_multiplier - # mask_multiplier = 1.0 - target = noise - # target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier) - # set masked multiplier to 1.0 so we dont double apply it - # mask_multiplier = 1.0 + # target_mask_multiplier = mask_multiplier + # mask_multiplier = 1.0 + target = noise + # target = (noise 
* mask_multiplier) + (prior_pred * prior_mask_multiplier) + # set masked multiplier to 1.0 so we dont double apply it + # mask_multiplier = 1.0 elif prior_pred is not None: assert not self.train_config.train_turbo # matching adapter prediction @@ -281,6 +280,9 @@ class SDTrainer(BaseSDTrainProcess): else: target = noise + if target is None: + target = noise + pred = noise_pred if self.train_config.train_turbo: @@ -360,6 +362,13 @@ class SDTrainer(BaseSDTrainProcess): loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) loss = loss.mean() + + # check for additional losses + if self.adapter is not None and hasattr(self.adapter, "additional_loss") and self.adapter.additional_loss is not None: + + loss = loss + self.adapter.additional_loss.mean() + self.adapter.additional_loss = None + return loss def preprocess_batch(self, batch: 'DataLoaderBatchDTO'): @@ -677,6 +686,7 @@ class SDTrainer(BaseSDTrainProcess): batch: 'DataLoaderBatchDTO', noise: torch.Tensor, unconditional_embeds: Optional[PromptEmbeds] = None, + conditioned_prompts=None, **kwargs ): # todo for embeddings, we need to run without trigger words @@ -980,6 +990,17 @@ class SDTrainer(BaseSDTrainProcess): # it will be injected into the tokenizer when called self.adapter(conditional_clip_embeds) + # do the custom adapter after the prior prediction + if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image: + quad_count = random.randint(1, 4) + self.adapter.train() + self.adapter.trigger_pre_te( + tensors_0_1=clip_images, + is_training=True, + has_been_preprocessed=True, + quad_count=quad_count + ) + with self.timer('encode_prompt'): unconditional_embeds = None if grad_on_text_encoder: @@ -1140,6 +1161,7 @@ class SDTrainer(BaseSDTrainProcess): unconditional_clip_embeds = unconditional_clip_embeds.detach() with self.timer('encode_adapter'): + self.adapter.train() conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds) if self.train_config.do_cfg: unconditional_embeds = self.adapter(unconditional_embeds.detach(), @@ -1170,8 +1192,10 @@ class SDTrainer(BaseSDTrainProcess): if self.train_config.inverted_mask_prior and batch.mask_tensor is not None: do_inverted_masked_prior = True + do_correct_pred_norm_prior = self.train_config.correct_pred_norm + if (( - has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior): + has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm): with self.timer('prior predict'): prior_pred = self.get_prior_prediction( noisy_latents=noisy_latents, @@ -1182,8 +1206,12 @@ class SDTrainer(BaseSDTrainProcess): pred_kwargs=pred_kwargs, noise=noise, batch=batch, - unconditional_embeds=unconditional_embeds - ).detach() + unconditional_embeds=unconditional_embeds, + conditioned_prompts=conditioned_prompts + ) + if prior_pred is not None: + prior_pred = prior_pred.detach() + # do the custom adapter after the prior prediction if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image: diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 19861c4d..6196ba68 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -130,6 +130,7 @@ class NetworkConfig: AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker'] +CLIPLayer = 
Literal['penultimate_hidden_states', 'image_embeds', 'last_hidden_state'] class AdapterConfig: def __init__(self, **kwargs): @@ -169,6 +170,13 @@ class AdapterConfig: self.class_names = kwargs.get('class_names', []) + self.clip_layer: CLIPLayer = kwargs.get('clip_layer', None) + if self.clip_layer is None: + if self.type.startswith('ip+'): + self.clip_layer = 'penultimate_hidden_states' + else: + self.clip_layer = 'last_hidden_state' + class EmbeddingConfig: def __init__(self, **kwargs): diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py index 3e02f497..c1c73909 100644 --- a/toolkit/custom_adapter.py +++ b/toolkit/custom_adapter.py @@ -438,7 +438,10 @@ class CustomAdapter(torch.nn.Module): is_unconditional=False, quad_count=4, ) -> PromptEmbeds: - if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora': + if self.adapter_type == 'ilora': + return prompt_embeds + + if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion': if is_unconditional: # we dont condition the negative embeds for photo maker return prompt_embeds.clone() @@ -503,7 +506,7 @@ class CustomAdapter(torch.nn.Module): self.token_mask ) return prompt_embeds - elif self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora': + elif self.adapter_type == 'clip_fusion': with torch.set_grad_enabled(is_training): if is_training and self.config.train_image_encoder: self.vision_encoder.train() @@ -535,22 +538,96 @@ class CustomAdapter(torch.nn.Module): if not is_training or not self.config.train_image_encoder: img_embeds = img_embeds.detach() - if self.adapter_type == 'ilora': - self.ilora_module.img_embeds = img_embeds - return prompt_embeds - else: - - prompt_embeds.text_embeds = self.clip_fusion_module( - prompt_embeds.text_embeds, - img_embeds - ) - return prompt_embeds + prompt_embeds.text_embeds = self.clip_fusion_module( + prompt_embeds.text_embeds, + img_embeds + ) + return prompt_embeds else: raise NotImplementedError + def trigger_pre_te( + self, + tensors_0_1: torch.Tensor, + is_training=False, + has_been_preprocessed=False, + quad_count=4, + ) -> PromptEmbeds: + if self.adapter_type == 'ilora': + with torch.no_grad(): + # on training the clip image is created in the dataloader + if not has_been_preprocessed: + # tensors should be 0-1 + if tensors_0_1.ndim == 3: + tensors_0_1 = tensors_0_1.unsqueeze(0) + # training tensors are 0 - 1 + tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16) + # if images are out of this range throw error + if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3: + raise ValueError("image tensor values must be between 0 and 1. 
Got min: {}, max: {}".format( + tensors_0_1.min(), tensors_0_1.max() + )) + clip_image = self.image_processor( + images=tensors_0_1, + return_tensors="pt", + do_resize=True, + do_rescale=False, + ).pixel_values + else: + clip_image = tensors_0_1 + clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach() + + if self.config.quad_image: + # split the 4x4 grid and stack on batch + ci1, ci2 = clip_image.chunk(2, dim=2) + ci1, ci3 = ci1.chunk(2, dim=3) + ci2, ci4 = ci2.chunk(2, dim=3) + to_cat = [] + for i, ci in enumerate([ci1, ci2, ci3, ci4]): + if i < quad_count: + to_cat.append(ci) + else: + break + + clip_image = torch.cat(to_cat, dim=0).detach() + + if self.adapter_type == 'ilora': + with torch.set_grad_enabled(is_training): + if is_training and self.config.train_image_encoder: + self.vision_encoder.train() + clip_image = clip_image.requires_grad_(True) + id_embeds = self.vision_encoder( + clip_image, + output_hidden_states=True, + ) + else: + with torch.no_grad(): + self.vision_encoder.eval() + id_embeds = self.vision_encoder( + clip_image, output_hidden_states=True + ) + + img_embeds = id_embeds['last_hidden_state'] + + if self.config.quad_image: + # get the outputs of the quat + chunks = img_embeds.chunk(quad_count, dim=0) + chunk_sum = torch.zeros_like(chunks[0]) + for chunk in chunks: + chunk_sum = chunk_sum + chunk + # get the mean of them + + img_embeds = chunk_sum / quad_count + + + if not is_training or not self.config.train_image_encoder: + img_embeds = img_embeds.detach() + + self.ilora_module.img_embeds = img_embeds + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: if self.config.type == 'photo_maker': yield from self.fuse_module.parameters(recurse) diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py index 6ad675b4..a08d523f 100644 --- a/toolkit/ip_adapter.py +++ b/toolkit/ip_adapter.py @@ -5,6 +5,7 @@ import sys from PIL import Image from torch.nn import Parameter +from torch.nn.modules.module import T from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from toolkit.models.clip_pre_processor import CLIPImagePreProcessor @@ -173,6 +174,7 @@ class IPAdapter(torch.nn.Module): self.input_size = 224 self.clip_noise_zero = True self.unconditional: torch.Tensor = None + self.additional_loss = None if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+': try: self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path) @@ -451,10 +453,7 @@ class IPAdapter(torch.nn.Module): ): with torch.no_grad(): device = self.sd_ref().unet.device - if self.config.type.startswith('ip+'): - clip_image_embeds = torch.cat([x['penultimate_hidden_states'] for x in image_embeds_list], dim=0) - else: - clip_image_embeds = torch.cat([x['image_embeds'] for x in image_embeds_list], dim=0) + clip_image_embeds = torch.cat([x[self.config.clip_layer] for x in image_embeds_list], dim=0) if self.config.quad_image: # get the outputs of the quat @@ -548,7 +547,7 @@ class IPAdapter(torch.nn.Module): # if drop: # clip_image = clip_image * 0 with torch.set_grad_enabled(is_training): - if is_training: + if is_training and self.config.train_image_encoder: self.image_encoder.train() clip_image = clip_image.requires_grad_(True) if self.preprocessor is not None: @@ -565,16 +564,39 @@ class IPAdapter(torch.nn.Module): clip_image, output_hidden_states=True ) - if self.config.type.startswith('ip+'): + if self.config.clip_layer == 'penultimate_hidden_states': # they skip last 
layer for ip+ # https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26 clip_image_embeds = clip_output.hidden_states[-2] + elif self.config.clip_layer == 'last_hidden_state': + clip_image_embeds = clip_output.hidden_states[-1] else: clip_image_embeds = clip_output.image_embeds if self.config.quad_image: # get the outputs of the quat chunks = clip_image_embeds.chunk(quad_count, dim=0) + if self.config.train_image_encoder and is_training: + # perform a loss across all chunks this will teach the vision encoder to + # identify similarities in our pairs of images and ignore things that do not make them similar + num_losses = 0 + total_loss = None + for chunk in chunks: + for chunk2 in chunks: + if chunk is not chunk2: + loss = F.mse_loss(chunk, chunk2) + if total_loss is None: + total_loss = loss + else: + total_loss = total_loss + loss + num_losses += 1 + if total_loss is not None: + total_loss = total_loss / num_losses + total_loss = total_loss * 1e-2 + if self.additional_loss is not None: + total_loss = total_loss + self.additional_loss + self.additional_loss = total_loss + chunk_sum = torch.zeros_like(chunks[0]) for chunk in chunks: chunk_sum = chunk_sum + chunk @@ -582,7 +604,7 @@ class IPAdapter(torch.nn.Module): clip_image_embeds = chunk_sum / quad_count - if not is_training: + if not is_training or not self.config.train_image_encoder: clip_image_embeds = clip_image_embeds.detach() return clip_image_embeds @@ -594,6 +616,17 @@ class IPAdapter(torch.nn.Module): embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1) return embeddings + + def train(self: T, mode: bool = True) -> T: + if self.config.train_image_encoder: + self.image_encoder.train(mode) + if not self.config.train_only_image_encoder: + for attn_processor in self.adapter_modules: + attn_processor.train(mode) + if self.image_proj_model is not None: + self.image_proj_model.train(mode) + return super().train(mode) + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: if self.config.train_only_image_encoder: yield from self.image_encoder.parameters(recurse) diff --git a/toolkit/models/ilora.py b/toolkit/models/ilora.py index 9352d164..968f28a9 100644 --- a/toolkit/models/ilora.py +++ b/toolkit/models/ilora.py @@ -53,7 +53,6 @@ class InstantLoRAMidModule(torch.nn.Module): # reshape if needed if len(x.shape) == 3: scaler = scaler.unsqueeze(1) - x = x * scaler except Exception as e: print(e) print(x.shape) diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index e92e835e..2b7519d3 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -354,6 +354,7 @@ class ToolkitNetworkMixin: self.is_ssd = is_ssd self.is_vega = is_vega self.is_v2 = is_v2 + self.is_v1 = not is_v2 and not is_sdxl and not is_ssd and not is_vega self.is_merged_in = False self.is_lorm = is_lorm self.network_config: NetworkConfig = network_config @@ -361,7 +362,7 @@ class ToolkitNetworkMixin: self.lorm_train_mode: Literal['local', None] = None self.can_merge_in = not is_lorm - def get_keymap(self: Network): + def get_keymap(self: Network, force_weight_mapping=False): use_weight_mapping = False if self.is_ssd: @@ -377,6 +378,9 @@ class ToolkitNetworkMixin: else: keymap_tail = 'sd1' # todo double check this + # use_weight_mapping = True + + if force_weight_mapping: use_weight_mapping = True # load keymap @@ -440,9 +444,9 @@ class ToolkitNetworkMixin: else: torch.save(save_dict, file) - def load_weights(self: Network, 
file): + def load_weights(self: Network, file, force_weight_mapping=False): # allows us to save and load to and from ldm weights - keymap = self.get_keymap() + keymap = self.get_keymap(force_weight_mapping) keymap = {} if keymap is None else keymap if os.path.splitext(file)[1] == ".safetensors": @@ -468,6 +472,11 @@ class ToolkitNetworkMixin: for key in to_delete: del load_sd[key] + print(f"Missing keys: {to_delete}") + if len(to_delete) > 0 and self.is_v1: + print(" Attempting to load with forced keymap") + return self.load_weights(file, force_weight_mapping=True) + info = self.load_state_dict(load_sd, False) if len(extra_dict.keys()) == 0: extra_dict = None diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index e10239ca..1a53e4b7 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -528,6 +528,14 @@ class StableDiffusion: ) gen_config.negative_prompt_2 = gen_config.negative_prompt + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): + self.adapter.trigger_pre_te( + tensors_0_1=validation_image, + is_training=False, + has_been_preprocessed=False, + quad_count=4 + ) + # encode the prompt ourselves so we can do fun stuff with embeddings conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
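
The correct_pred_norm branch added to SDTrainer.calculate_loss re-targets the noise toward the statistics of the prior prediction. A minimal standalone sketch of that statistic-matching step follows; the function name and the multiplier argument are illustrative rather than toolkit API, and it assumes (B, C, H, W) latents as in the hunk above, with multiplier standing in for correct_pred_norm_multiplier (0 leaves the target untouched, 1 moves it fully onto the prior's statistics).

import torch

def match_noise_to_prior(noise: torch.Tensor,
                         noise_pred: torch.Tensor,
                         prior_pred: torch.Tensor,
                         multiplier: float = 1.0,
                         eps: float = 1e-5) -> torch.Tensor:
    # Shift the noise target's per-channel mean/std toward the prior
    # prediction's statistics, taken over the spatial dims with keepdim
    # so they broadcast back over the latent.
    with torch.no_grad():
        prior_mean = prior_pred.mean([2, 3], keepdim=True)
        prior_std = prior_pred.std([2, 3], keepdim=True)
        noise_mean = noise_pred.mean([2, 3], keepdim=True)
        noise_std = noise_pred.std([2, 3], keepdim=True)

        # move only a fraction of the way toward the prior's statistics
        target_mean = noise_mean + (prior_mean - noise_mean) * multiplier
        target_std = noise_std + (prior_std - noise_std) * multiplier

        # standardize against the model's current prediction, then rescale
        noise = (noise - noise_mean) / (noise_std + eps)
        noise = noise * (target_std + eps) + target_mean
    return noise.detach()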
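
The additional_loss wired through IPAdapter and consumed at the end of calculate_loss is a pairwise MSE over the embeddings of the grid quadrants, scaled by a hard-coded 1e-2. A self-contained sketch of that regularizer, with the scale exposed as a hypothetical weight argument:

import torch
import torch.nn.functional as F

def quad_similarity_loss(clip_image_embeds: torch.Tensor,
                         quad_count: int = 4,
                         weight: float = 1e-2) -> torch.Tensor:
    # Pairwise MSE between the embeddings of the quadrants of one image grid.
    # Pulls the encoder toward features the quadrants share and away from
    # quadrant-specific detail.
    chunks = clip_image_embeds.chunk(quad_count, dim=0)
    total, num = None, 0
    for a in chunks:
        for b in chunks:
            if a is not b:
                pair = F.mse_loss(a, b)
                total = pair if total is None else total + pair
                num += 1
    if total is None:
        # nothing to compare (quad_count == 1)
        return clip_image_embeds.new_zeros(())
    return (total / num) * weight

The trainer then adds adapter.additional_loss.mean() to the main loss and clears the attribute, as the new block at the end of calculate_loss does.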
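
trigger_pre_te splits a 2x2 grid image into up to quad_count tiles stacked on the batch dimension before running the vision encoder. A sketch of just that chunking, assuming a (B, C, H, W) tensor; the tile order mirrors ci1..ci4 in the hunk above:

import torch

def split_quad_grid(clip_image: torch.Tensor, quad_count: int = 4) -> torch.Tensor:
    # Split a 2x2 grid image into up to quad_count tiles stacked on the batch dim.
    top, bottom = clip_image.chunk(2, dim=2)   # split height
    tl, tr = top.chunk(2, dim=3)               # split width
    bl, br = bottom.chunk(2, dim=3)
    tiles = [tl, bl, tr, br][:quad_count]      # same order as ci1..ci4 above
    return torch.cat(tiles, dim=0).detach()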
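
The new clip_layer config option replaces the ip+ prefix check when choosing which CLIP vision output feeds the adapter. A sketch of the selection, assuming the encoder was called with output_hidden_states=True (the helper name is illustrative, not part of the toolkit):

import torch

def select_clip_features(clip_output, clip_layer: str) -> torch.Tensor:
    # Pick the feature tensor named by config.clip_layer.
    if clip_layer == 'penultimate_hidden_states':
        # the IP-Adapter "plus" variants skip the final transformer layer
        return clip_output.hidden_states[-2]
    if clip_layer == 'last_hidden_state':
        return clip_output.hidden_states[-1]
    # default: the pooled, projected image embedding
    return clip_output.image_embeds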