mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Bug fixes. Added IP adapter training for Pixart
This commit is contained in:
@@ -346,8 +346,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
print("Prior loss is nan")
|
||||
prior_loss = None
|
||||
else:
|
||||
# prior_loss = prior_loss.mean([1, 2, 3])
|
||||
loss = loss + prior_loss
|
||||
prior_loss = prior_loss.mean([1, 2, 3])
|
||||
# loss = loss + prior_loss
|
||||
# loss = loss + prior_loss
|
||||
# loss = loss + prior_loss
|
||||
loss = loss.mean([1, 2, 3])
|
||||
if prior_loss is not None:
|
||||
@@ -731,6 +732,15 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# self.network.multiplier = 0.0
|
||||
self.sd.unet.eval()
|
||||
|
||||
if self.adapter is not None and isinstance(self.adapter, IPAdapter):
|
||||
# we need to remove the image embeds from the prompt
|
||||
embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach()
|
||||
end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens
|
||||
embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :]
|
||||
if unconditional_embeds is not None:
|
||||
unconditional_embeds = unconditional_embeds.clone().detach()
|
||||
unconditional_embeds.text_embeds = unconditional_embeds.text_embeds[:, :end_pos]
|
||||
|
||||
if unconditional_embeds is not None:
|
||||
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user