Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Made preliminary arch for flux IP adapter training
@@ -838,8 +838,8 @@ class SDTrainer(BaseSDTrainProcess):
                 # self.network.multiplier = 0.0
             self.sd.unet.eval()

-            if self.adapter is not None and isinstance(self.adapter, IPAdapter):
-                # we need to remove the image embeds from the prompt
+            if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not self.sd.is_flux:
+                # we need to remove the image embeds from the prompt except for flux
                 embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach()
                 end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens
                 embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :]
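
For context on the non-flux branch above: the adapter appends its image tokens to the end of text_embeds, so evaluating without the adapter means slicing those trailing tokens back off. A minimal sketch of that slicing, with illustrative shapes and a hypothetical num_tokens value:

    import torch

    num_tokens = 4                                       # hypothetical IP-adapter token count
    text_embeds = torch.randn(2, 77 + num_tokens, 768)   # [batch, prompt tokens + image tokens, dim]

    end_pos = text_embeds.shape[1] - num_tokens
    stripped = text_embeds[:, :end_pos, :]               # drop the trailing image tokens
    assert stripped.shape[1] == 77

For flux the image embeds never live in text_embeds (see the forward() change further down), so there is nothing to strip.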
@@ -1268,7 +1268,7 @@ class SDTrainer(BaseSDTrainProcess):
         if has_clip_image_embeds:
             # todo handle reg images better than this
             if is_reg:
-                # get unconditional image imbeds from cache
+                # get unconditional image embeds from cache
                 embeds = [
                     load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
                     range(noisy_latents.shape[0])
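
The reg branch pulls one cached unconditional CLIP image embed per sample via safetensors' load_file. A minimal sketch of how such per-sample embeds would be batched afterwards, assuming illustrative shapes rather than the repo's actual cache layout:

    import torch

    batch_size = 4
    cached = [torch.randn(257, 1280) for _ in range(batch_size)]   # hypothetical per-sample CLIP hidden states
    clip_image_embeds = torch.stack(cached, dim=0)                 # [batch, 257, 1280]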
@@ -1353,10 +1353,20 @@ class SDTrainer(BaseSDTrainProcess):

                 with self.timer('encode_adapter'):
                     self.adapter.train()
-                    conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
+                    conditional_embeds = self.adapter(
+                        conditional_embeds.detach(),
+                        conditional_clip_embeds,
+                        is_unconditional=False
+                    )
                     if self.train_config.do_cfg:
-                        unconditional_embeds = self.adapter(unconditional_embeds.detach(),
-                                                            unconditional_clip_embeds)
+                        unconditional_embeds = self.adapter(
+                            unconditional_embeds.detach(),
+                            unconditional_clip_embeds,
+                            is_unconditional=True
+                        )
+                    else:
+                        # wipe out unconditional
+                        self.adapter.last_unconditional = None

         if self.adapter and isinstance(self.adapter, ReferenceAdapter):
             # pass in our scheduler
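
On flux, the adapter call above stashes the projected image embeds on the adapter instead of returning them inside the prompt embeds, and the unconditional slot is cleared when CFG is off so the attention processor never sees a stale batch. A minimal, self-contained sketch of that stash-and-clear pattern (AdapterStub is a stand-in, not the real IPAdapter):

    import torch

    class AdapterStub:
        def __init__(self):
            self.last_conditional = None
            self.last_unconditional = None

        def __call__(self, text_embeds, clip_embeds, is_unconditional=False):
            image_tokens = clip_embeds  # projection omitted in this sketch
            if is_unconditional:
                self.last_unconditional = image_tokens
            else:
                self.last_conditional = image_tokens
            return text_embeds          # text embeds pass through untouched for flux

    adapter = AdapterStub()
    do_cfg = False

    cond = adapter(torch.randn(2, 512, 4096), torch.randn(2, 4, 3072))
    if do_cfg:
        uncond = adapter(torch.randn(2, 512, 4096), torch.randn(2, 4, 3072), is_unconditional=True)
    else:
        adapter.last_unconditional = None  # wipe the stale unconditional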
@@ -301,14 +301,13 @@ class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
         if not is_active:
             ip_hidden_states = None
         else:
-            # get encoder_hidden_states, ip_hidden_states
-            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
-            encoder_hidden_states, ip_hidden_states = (
-                encoder_hidden_states[:, :end_pos, :],
-                encoder_hidden_states[:, end_pos:, :],
-            )
-            # just strip it for now?
-            image_rotary_emb = image_rotary_emb[:, :, :-self.num_tokens, :, :, :]
+            # get ip hidden states. Should be stored
+            ip_hidden_states = self.adapter_ref().last_conditional
+            # add unconditional to front if it exists
+            if ip_hidden_states.shape[0] * 2 == batch_size:
+                if self.adapter_ref().last_unconditional is None:
+                    raise ValueError("Unconditional is None but should not be")
+                ip_hidden_states = torch.cat([self.adapter_ref().last_unconditional, ip_hidden_states], dim=0)

         # `context` projections.
         encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
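
In the flux attention processor, the image tokens now come from the adapter's stash rather than from the tail of encoder_hidden_states. When the incoming batch is doubled for CFG, the stored unconditional is put in front so the IP hidden states line up with the (unconditional, conditional) halves. A minimal sketch of that alignment check, with illustrative shapes:

    import torch

    last_conditional = torch.randn(2, 4, 3072)    # stored by the adapter forward pass
    last_unconditional = torch.randn(2, 4, 3072)  # stored only when CFG is on
    batch_size = 4                                # doubled batch: (uncond, cond)

    ip_hidden_states = last_conditional
    if ip_hidden_states.shape[0] * 2 == batch_size:
        if last_unconditional is None:
            raise ValueError("Unconditional is None but should not be")
        # unconditional half goes first so it matches the doubled hidden-state batch
        ip_hidden_states = torch.cat([last_unconditional, ip_hidden_states], dim=0)

    assert ip_hidden_states.shape[0] == batch_size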
@@ -411,6 +410,10 @@ class IPAdapter(torch.nn.Module):
         self.input_size = 224
         self.clip_noise_zero = True
         self.unconditional: torch.Tensor = None
+
+        self.last_conditional: torch.Tensor = None
+        self.last_unconditional: torch.Tensor = None
+
         self.additional_loss = None
         if self.config.image_encoder_arch.startswith("clip"):
             try:
@@ -574,12 +577,14 @@ class IPAdapter(torch.nn.Module):
             max_seq_len = int(
                 image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])

-        if is_pixart or is_flux:
-            # heads = 20
+        if is_pixart:
             heads = 20
-            # dim = 4096
             dim = 1280
             output_dim = 4096
+        elif is_flux:
+            heads = 20
+            dim = 1280
+            output_dim = 3072
         else:
             output_dim = sd.unet.config['cross_attention_dim']

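
PixArt and flux now take separate branches: both keep heads=20 and dim=1280 for the projector, but flux projects to 3072, matching the flux transformer's inner hidden size, instead of 4096. A minimal sketch of what output_dim controls, assuming a plain linear projection as a stand-in for the real image_proj_model:

    import torch

    dim, output_dim = 1280, 3072                  # flux branch values from the diff
    proj = torch.nn.Linear(dim, output_dim)       # stand-in for the real projector

    clip_features = torch.randn(2, 257, dim)      # hypothetical CLIP vision hidden states
    print(proj(clip_features).shape)              # torch.Size([2, 257, 3072])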
@@ -1136,10 +1141,18 @@ class IPAdapter(torch.nn.Module):
         return clip_image_embeds

     # use drop for prompt dropout, or negatives
-    def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
+    def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor, is_unconditional=False) -> PromptEmbeds:
         clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
-        embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
+        if self.sd_ref().is_flux:
+            # do not attach to text embeds for flux, we will save and grab them as it messes
+            # with the RoPE to have them in the same tensor
+            if is_unconditional:
+                self.last_unconditional = image_prompt_embeds
+            else:
+                self.last_conditional = image_prompt_embeds
+        else:
+            embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
         return embeddings

     def train(self: T, mode: bool = True) -> T:
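
Note that forward() reads the model through self.sd_ref() rather than a direct attribute; the call syntax suggests a weakref, which would avoid a reference cycle between the adapter and the model it decorates. A minimal sketch of that pattern, offered as an assumption about the surrounding code rather than a copy of it:

    import weakref

    class Model:
        is_flux = True

    model = Model()
    sd_ref = weakref.ref(model)   # store a callable instead of a hard reference
    assert sd_ref().is_flux       # dereference on use, as in self.sd_ref().is_flux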
@@ -1124,8 +1124,8 @@ class StableDiffusion:
                 conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
                 unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image,
                                                                                             True)
-                conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
-                unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
+                conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False)
+                unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True)

             if self.adapter is not None and isinstance(self.adapter,
                                                        CustomAdapter) and validation_image is not None:
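
Both the conditional and unconditional adapter calls are needed at validation because sampling combines the two predictions with classifier-free guidance. A minimal sketch of the standard CFG combine step (standard formulation, not code from this repo):

    import torch

    guidance_scale = 4.0
    noise_pred_uncond = torch.randn(1, 16, 64, 64)   # illustrative latent shapes
    noise_pred_cond = torch.randn(1, 16, 64, 64)

    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)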