Bug fixes. Added IP adapter training for PixArt

This commit is contained in:
Jaret Burkett
2024-02-17 10:06:57 -07:00
parent 93b52932c1
commit 2478554c95
4 changed files with 278 additions and 49 deletions

View File

@@ -346,8 +346,9 @@ class SDTrainer(BaseSDTrainProcess):
print("Prior loss is nan")
prior_loss = None
else:
# prior_loss = prior_loss.mean([1, 2, 3])
loss = loss + prior_loss
prior_loss = prior_loss.mean([1, 2, 3])
# loss = loss + prior_loss
# loss = loss + prior_loss
# loss = loss + prior_loss
loss = loss.mean([1, 2, 3])
if prior_loss is not None:
@@ -731,6 +732,15 @@ class SDTrainer(BaseSDTrainProcess):
# self.network.multiplier = 0.0
self.sd.unet.eval()
if self.adapter is not None and isinstance(self.adapter, IPAdapter):
# we need to remove the image embeds from the prompt
embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach()
end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens
embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :]
if unconditional_embeds is not None:
unconditional_embeds = unconditional_embeds.clone().detach()
unconditional_embeds.text_embeds = unconditional_embeds.text_embeds[:, :end_pos]
if unconditional_embeds is not None:
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()