mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Fixed ip adapter training. Works now
This commit is contained in:
@@ -549,11 +549,13 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
self.timer.stop('preprocess_batch')
|
||||
|
||||
is_reg = False
|
||||
with torch.no_grad():
|
||||
loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
|
||||
for idx, file_item in enumerate(batch.file_items):
|
||||
if file_item.is_reg:
|
||||
loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight
|
||||
is_reg = True
|
||||
|
||||
adapter_images = None
|
||||
sigmas = None
|
||||
@@ -764,11 +766,27 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
|
||||
with self.timer('encode_adapter'):
|
||||
if self.adapter and isinstance(self.adapter, IPAdapter):
|
||||
with self.timer('encode_adapter_embeds'):
|
||||
with torch.no_grad():
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
|
||||
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
|
||||
if has_adapter_img:
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
|
||||
adapter_images.detach().to(self.device_torch, dtype=dtype))
|
||||
elif is_reg:
|
||||
# we will zero it out in the img embedder
|
||||
adapter_img = torch.zeros(
|
||||
(noisy_latents.shape[0], 3, 512, 512),
|
||||
device=self.device_torch, dtype=dtype
|
||||
)
|
||||
# drop will zero it out
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
|
||||
adapter_img, drop=True
|
||||
)
|
||||
else:
|
||||
raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
|
||||
|
||||
with self.timer('encode_adapter'):
|
||||
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds.detach())
|
||||
|
||||
self.before_unet_predict()
|
||||
# do a prior pred if we have an unconditional image, we will swap out the giadance later
|
||||
|
||||
Reference in New Issue
Block a user