mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Made peleminary arch for flux ip adapter training
This commit is contained in:
@@ -838,8 +838,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# self.network.multiplier = 0.0
|
||||
self.sd.unet.eval()
|
||||
|
||||
if self.adapter is not None and isinstance(self.adapter, IPAdapter):
|
||||
# we need to remove the image embeds from the prompt
|
||||
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not self.sd.is_flux:
|
||||
# we need to remove the image embeds from the prompt except for flux
|
||||
embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach()
|
||||
end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens
|
||||
embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :]
|
||||
@@ -1268,7 +1268,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if has_clip_image_embeds:
|
||||
# todo handle reg images better than this
|
||||
if is_reg:
|
||||
# get unconditional image imbeds from cache
|
||||
# get unconditional image embeds from cache
|
||||
embeds = [
|
||||
load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
|
||||
range(noisy_latents.shape[0])
|
||||
@@ -1353,10 +1353,20 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
with self.timer('encode_adapter'):
|
||||
self.adapter.train()
|
||||
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
|
||||
conditional_embeds = self.adapter(
|
||||
conditional_embeds.detach(),
|
||||
conditional_clip_embeds,
|
||||
is_unconditional=False
|
||||
)
|
||||
if self.train_config.do_cfg:
|
||||
unconditional_embeds = self.adapter(unconditional_embeds.detach(),
|
||||
unconditional_clip_embeds)
|
||||
unconditional_embeds = self.adapter(
|
||||
unconditional_embeds.detach(),
|
||||
unconditional_clip_embeds,
|
||||
is_unconditional=True
|
||||
)
|
||||
else:
|
||||
# wipe out unconsitional
|
||||
self.adapter.last_unconditional = None
|
||||
|
||||
if self.adapter and isinstance(self.adapter, ReferenceAdapter):
|
||||
# pass in our scheduler
|
||||
|
||||
Reference in New Issue
Block a user