Made peleminary arch for flux ip adapter training

This commit is contained in:
Jaret Burkett
2024-08-28 08:55:39 -06:00
parent 3843e0d148
commit 60232def91
3 changed files with 44 additions and 21 deletions

View File

@@ -838,8 +838,8 @@ class SDTrainer(BaseSDTrainProcess):
# self.network.multiplier = 0.0
self.sd.unet.eval()
if self.adapter is not None and isinstance(self.adapter, IPAdapter):
# we need to remove the image embeds from the prompt
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not self.sd.is_flux:
# we need to remove the image embeds from the prompt except for flux
embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach()
end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens
embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :]
@@ -1268,7 +1268,7 @@ class SDTrainer(BaseSDTrainProcess):
if has_clip_image_embeds:
# todo handle reg images better than this
if is_reg:
# get unconditional image imbeds from cache
# get unconditional image embeds from cache
embeds = [
load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
range(noisy_latents.shape[0])
@@ -1353,10 +1353,20 @@ class SDTrainer(BaseSDTrainProcess):
with self.timer('encode_adapter'):
self.adapter.train()
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
conditional_embeds = self.adapter(
conditional_embeds.detach(),
conditional_clip_embeds,
is_unconditional=False
)
if self.train_config.do_cfg:
unconditional_embeds = self.adapter(unconditional_embeds.detach(),
unconditional_clip_embeds)
unconditional_embeds = self.adapter(
unconditional_embeds.detach(),
unconditional_clip_embeds,
is_unconditional=True
)
else:
# wipe out unconsitional
self.adapter.last_unconditional = None
if self.adapter and isinstance(self.adapter, ReferenceAdapter):
# pass in our scheduler