diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 86513c32..07e8240c 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -838,8 +838,8 @@ class SDTrainer(BaseSDTrainProcess):
         # self.network.multiplier = 0.0
         self.sd.unet.eval()
 
-        if self.adapter is not None and isinstance(self.adapter, IPAdapter):
-            # we need to remove the image embeds from the prompt
+        if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not self.sd.is_flux:
+            # we need to remove the image embeds from the prompt, except for flux
             embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach()
             end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens
             embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :]
@@ -1268,7 +1268,7 @@ class SDTrainer(BaseSDTrainProcess):
         if has_clip_image_embeds:
             # todo handle reg images better than this
             if is_reg:
-                # get unconditional image imbeds from cache
+                # get unconditional image embeds from cache
                 embeds = [
                     load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in range(noisy_latents.shape[0])
                 ]
@@ -1353,10 +1353,20 @@ class SDTrainer(BaseSDTrainProcess):
             with self.timer('encode_adapter'):
                 self.adapter.train()
-                conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
+                conditional_embeds = self.adapter(
+                    conditional_embeds.detach(),
+                    conditional_clip_embeds,
+                    is_unconditional=False
+                )
                 if self.train_config.do_cfg:
-                    unconditional_embeds = self.adapter(unconditional_embeds.detach(),
-                                                        unconditional_clip_embeds)
+                    unconditional_embeds = self.adapter(
+                        unconditional_embeds.detach(),
+                        unconditional_clip_embeds,
+                        is_unconditional=True
+                    )
+                else:
+                    # wipe out the unconditional embeds
+                    self.adapter.last_unconditional = None
 
         if self.adapter and isinstance(self.adapter, ReferenceAdapter):
             # pass in our scheduler
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index 8d735b78..fea1e81f 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -301,14 +301,13 @@ class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
         if not is_active:
             ip_hidden_states = None
         else:
-            # get encoder_hidden_states, ip_hidden_states
-            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
-            encoder_hidden_states, ip_hidden_states = (
-                encoder_hidden_states[:, :end_pos, :],
-                encoder_hidden_states[:, end_pos:, :],
-            )
-            # just strip it for now?
-            image_rotary_emb = image_rotary_emb[:, :, :-self.num_tokens, :, :, :]
+            # get the ip hidden states; they should already be stored on the adapter
+            ip_hidden_states = self.adapter_ref().last_conditional
+            # prepend the unconditional embeds if this is a CFG batch
+            if ip_hidden_states.shape[0] * 2 == batch_size:
+                if self.adapter_ref().last_unconditional is None:
+                    raise ValueError("Unconditional is None but should not be")
+                ip_hidden_states = torch.cat([self.adapter_ref().last_unconditional, ip_hidden_states], dim=0)
 
         # `context` projections.
         encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
@@ -411,6 +410,10 @@ class IPAdapter(torch.nn.Module):
         self.input_size = 224
         self.clip_noise_zero = True
         self.unconditional: torch.Tensor = None
+
+        self.last_conditional: torch.Tensor = None
+        self.last_unconditional: torch.Tensor = None
+
         self.additional_loss = None
         if self.config.image_encoder_arch.startswith("clip"):
             try:
@@ -574,12 +577,14 @@ class IPAdapter(torch.nn.Module):
             max_seq_len = int(
                 image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
-            if is_pixart or is_flux:
-                # heads = 20
+            if is_pixart:
                 heads = 20
-                # dim = 4096
                 dim = 1280
                 output_dim = 4096
+            elif is_flux:
+                heads = 20
+                dim = 1280
+                output_dim = 3072
             else:
                 output_dim = sd.unet.config['cross_attention_dim']
@@ -1136,10 +1141,18 @@ class IPAdapter(torch.nn.Module):
             return clip_image_embeds
 
     # use drop for prompt dropout, or negatives
-    def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
+    def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor, is_unconditional=False) -> PromptEmbeds:
         clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
-        embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
+        if self.sd_ref().is_flux:
+            # do not attach the image embeds to the text embeds for flux; stash them
+            # and fetch them later, since keeping them in the same tensor interferes
+            # with the RoPE
+            if is_unconditional:
+                self.last_unconditional = image_prompt_embeds
+            else:
+                self.last_conditional = image_prompt_embeds
+        else:
+            embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
         return embeddings
 
     def train(self: T, mode: bool = True) -> T:
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 9e27a1ed..db1f8e81 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -1124,8 +1124,8 @@ class StableDiffusion:
                 conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
                 unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, True)
-                conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
-                unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
+                conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False)
+                unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True)
 
             if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None:
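
For reference, here is a minimal, self-contained sketch of the stash-and-fetch flow these changes introduce: the adapter's forward stores the projected image-prompt embeds on `last_conditional` / `last_unconditional` instead of concatenating them onto the text embeds, and the attention processor later rebuilds the (possibly CFG-doubled) batch from those slots. `ToyAdapter`, `gather_ip_hidden_states`, and the shapes below are illustrative stand-ins, not the real `IPAdapter` / `CustomIPFluxAttnProcessor2_0` classes.

```python
from typing import Optional

import torch


class ToyAdapter(torch.nn.Module):
    """Stashes image-prompt embeds instead of concatenating them (the flux path)."""

    def __init__(self, num_tokens: int = 4, dim: int = 3072):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)
        self.num_tokens = num_tokens
        self.last_conditional: Optional[torch.Tensor] = None
        self.last_unconditional: Optional[torch.Tensor] = None

    def forward(self, clip_image_embeds: torch.Tensor, is_unconditional: bool = False) -> torch.Tensor:
        image_prompt_embeds = self.proj(clip_image_embeds)
        # Mirror the diff: store the embeds rather than cat-ing them onto the
        # text embeds, so the text tokens keep their RoPE positions.
        if is_unconditional:
            self.last_unconditional = image_prompt_embeds
        else:
            self.last_conditional = image_prompt_embeds
        return image_prompt_embeds


def gather_ip_hidden_states(adapter: ToyAdapter, batch_size: int) -> torch.Tensor:
    """Rebuilds the ip hidden states the way the attention processor does."""
    ip_hidden_states = adapter.last_conditional
    # A doubled batch means CFG: the unconditional half goes in front.
    if ip_hidden_states.shape[0] * 2 == batch_size:
        if adapter.last_unconditional is None:
            raise ValueError("Unconditional is None but should not be")
        ip_hidden_states = torch.cat([adapter.last_unconditional, ip_hidden_states], dim=0)
    return ip_hidden_states


adapter = ToyAdapter()
adapter(torch.randn(2, 4, 3072), is_unconditional=False)
adapter(torch.randn(2, 4, 3072), is_unconditional=True)
print(gather_ip_hidden_states(adapter, batch_size=4).shape)  # torch.Size([4, 4, 3072])
```

Keeping the embeds on the adapter also explains why the trainer clears `last_unconditional` when CFG is off: a stale unconditional left over from an earlier CFG batch would otherwise be prepended the next time the batch size happens to double.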