Fixed ip adapter training. Works now

This commit is contained in:
Jaret Burkett
2023-12-17 08:22:59 -07:00
parent 13d32423f6
commit b653906715
7 changed files with 102 additions and 40 deletions

View File

@@ -549,11 +549,13 @@ class SDTrainer(BaseSDTrainProcess):
self.timer.stop('preprocess_batch')
is_reg = False
with torch.no_grad():
loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
for idx, file_item in enumerate(batch.file_items):
if file_item.is_reg:
loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight
is_reg = True
adapter_images = None
sigmas = None
@@ -764,11 +766,27 @@ class SDTrainer(BaseSDTrainProcess):
batch=batch,
)
if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter'):
if self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter_embeds'):
with torch.no_grad():
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
if has_adapter_img:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
adapter_images.detach().to(self.device_torch, dtype=dtype))
elif is_reg:
# we will zero it out in the img embedder
adapter_img = torch.zeros(
(noisy_latents.shape[0], 3, 512, 512),
device=self.device_torch, dtype=dtype
)
# drop will zero it out
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
adapter_img, drop=True
)
else:
raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
with self.timer('encode_adapter'):
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds.detach())
self.before_unet_predict()
# do a prior pred if we have an unconditional image, we will swap out the giadance later