Made preliminary architecture for Flux IP adapter training

This commit is contained in:
Jaret Burkett
2024-08-28 08:55:39 -06:00
parent 3843e0d148
commit 60232def91
3 changed files with 44 additions and 21 deletions

View File

@@ -838,8 +838,8 @@ class SDTrainer(BaseSDTrainProcess):
# self.network.multiplier = 0.0
self.sd.unet.eval()
if self.adapter is not None and isinstance(self.adapter, IPAdapter):
# we need to remove the image embeds from the prompt
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not self.sd.is_flux:
# we need to remove the image embeds from the prompt except for flux
embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach()
end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens
embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :]
@@ -1268,7 +1268,7 @@ class SDTrainer(BaseSDTrainProcess):
if has_clip_image_embeds:
# todo handle reg images better than this
if is_reg:
# get unconditional image imbeds from cache
# get unconditional image embeds from cache
embeds = [
load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
range(noisy_latents.shape[0])
@@ -1353,10 +1353,20 @@ class SDTrainer(BaseSDTrainProcess):
with self.timer('encode_adapter'):
self.adapter.train()
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
conditional_embeds = self.adapter(
conditional_embeds.detach(),
conditional_clip_embeds,
is_unconditional=False
)
if self.train_config.do_cfg:
unconditional_embeds = self.adapter(unconditional_embeds.detach(),
unconditional_clip_embeds)
unconditional_embeds = self.adapter(
unconditional_embeds.detach(),
unconditional_clip_embeds,
is_unconditional=True
)
else:
# wipe out unconditional
self.adapter.last_unconditional = None
if self.adapter and isinstance(self.adapter, ReferenceAdapter):
# pass in our scheduler

View File

@@ -301,14 +301,13 @@ class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
if not is_active:
ip_hidden_states = None
else:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
# just strip it for now?
image_rotary_emb = image_rotary_emb[:, :, :-self.num_tokens, :, :, :]
# get ip hidden states. Should be stored
ip_hidden_states = self.adapter_ref().last_conditional
# add unconditional to front if it exists
if ip_hidden_states.shape[0] * 2 == batch_size:
if self.adapter_ref().last_unconditional is None:
raise ValueError("Unconditional is None but should not be")
ip_hidden_states = torch.cat([self.adapter_ref().last_unconditional, ip_hidden_states], dim=0)
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
@@ -411,6 +410,10 @@ class IPAdapter(torch.nn.Module):
self.input_size = 224
self.clip_noise_zero = True
self.unconditional: torch.Tensor = None
self.last_conditional: torch.Tensor = None
self.last_unconditional: torch.Tensor = None
self.additional_loss = None
if self.config.image_encoder_arch.startswith("clip"):
try:
@@ -574,12 +577,14 @@ class IPAdapter(torch.nn.Module):
max_seq_len = int(
image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
if is_pixart or is_flux:
# heads = 20
if is_pixart:
heads = 20
# dim = 4096
dim = 1280
output_dim = 4096
elif is_flux:
heads = 20
dim = 1280
output_dim = 3072
else:
output_dim = sd.unet.config['cross_attention_dim']
@@ -1136,10 +1141,18 @@ class IPAdapter(torch.nn.Module):
return clip_image_embeds
# use drop for prompt dropout, or negatives
def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor, is_unconditional=False) -> PromptEmbeds:
clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
if self.sd_ref().is_flux:
# do not attach to text embeds for flux, we will save and grab them as it messes
# with the RoPE to have them in the same tensor
if is_unconditional:
self.last_unconditional = image_prompt_embeds
else:
self.last_conditional = image_prompt_embeds
else:
embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
return embeddings
def train(self: T, mode: bool = True) -> T:

View File

@@ -1124,8 +1124,8 @@ class StableDiffusion:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image,
True)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False)
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True)
if self.adapter is not None and isinstance(self.adapter,
CustomAdapter) and validation_image is not None: