Switched the IP adapter dataloader to clip_image paths, so that control paths can be used for training assistant adapters while training IP adapters

Jaret Burkett
2023-12-20 10:32:24 -07:00
parent dfb64b5957
commit 0f597f453e
5 changed files with 94 additions and 12 deletions


@@ -567,6 +567,10 @@ class SDTrainer(BaseSDTrainProcess):
 network_weight_list = network_weight_list + network_weight_list
 has_adapter_img = batch.control_tensor is not None
+has_clip_image = batch.clip_image_tensor is not None
+if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
+    raise ValueError("IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config")
 match_adapter_assist = False
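
Note: for anyone hitting the new ValueError, the fix is a one-key rename in the dataset config. A minimal before/after sketch; only the 'control_path' and 'clip_image_path' keys come from the error message, every other key and path here is illustrative:

# before (hypothetical dataset entry): CLIP reference images rode in on control_path
dataset_config = {
    'folder_path': '/path/to/training/images',
    'control_path': '/path/to/clip/reference/images',
}

# after: IP adapter reference images move to clip_image_path, which frees
# control_path to carry control images for an assistant adapter
dataset_config = {
    'folder_path': '/path/to/training/images',
    'clip_image_path': '/path/to/clip/reference/images',
    # 'control_path': '/path/to/control/images',  # optional, assistant adapter input
}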
@@ -604,10 +608,14 @@ class SDTrainer(BaseSDTrainProcess):
     adapter_images = adapter_images[:, :in_channels, :, :]
 else:
     raise NotImplementedError("Adapter images must now be loaded with the dataloader")
 # not 100% sure what this does, but they do it here:
 # https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
 # sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
 # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
+clip_images = None
+if has_clip_image:
+    with self.timer('get_clip_images'):
+        # todo: move this to the data loader
+        if batch.clip_image_tensor is not None:
+            clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach()
 mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
 if batch.mask_tensor is not None:
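
Note on the "not 100% sure what this does" comment: the commented-out lines appear to be EDM-style input preconditioning, scaling the noisy input by c_in(sigma) = 1 / sqrt(sigma^2 + 1), which matches the EDM c_in with sigma_data = 1. A self-contained sketch of that scaling, assuming sigmas comes back from a get_sigmas-style helper with one value per batch element:

import torch

def scale_noisy_input(noisy_latents: torch.Tensor, sigmas: torch.Tensor) -> torch.Tensor:
    # Unsqueeze sigma from [B] to [B, 1, 1, 1] so it broadcasts over [B, C, H, W].
    while sigmas.dim() < noisy_latents.dim():
        sigmas = sigmas.unsqueeze(-1)
    # c_in(sigma) = 1 / sqrt(sigma^2 + 1): keeps the effective input variance
    # roughly constant across noise levels.
    return noisy_latents / ((sigmas ** 2 + 1) ** 0.5)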
@@ -697,6 +705,10 @@ class SDTrainer(BaseSDTrainProcess):
     adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0)
 else:
     adapter_images_list = [None for _ in range(batch_size)]
+if clip_images is not None:
+    clip_images_list = torch.chunk(clip_images, batch_size, dim=0)
+else:
+    clip_images_list = [None for _ in range(batch_size)]
 mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
 if prompts_2 is None:
     prompt_2_list = [None for _ in range(batch_size)]
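
The chunking above splits the batch into per-sample views so each sample can be stepped through individually: torch.chunk(x, batch_size, dim=0) on a tensor with batch_size rows yields batch_size tensors of shape [1, ...], and the [None] per-sample fallback keeps every list the same length for the zip below. A quick standalone check:

import torch

batch_size = 4
clip_images = torch.randn(batch_size, 3, 224, 224)  # illustrative shape

clip_images_list = torch.chunk(clip_images, batch_size, dim=0)
assert len(clip_images_list) == batch_size
assert clip_images_list[0].shape == (1, 3, 224, 224)

# When the modality is absent, a None per sample keeps the lists aligned.
missing_list = [None for _ in range(batch_size)]
assert len(missing_list) == len(clip_images_list)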
@@ -710,19 +722,21 @@ class SDTrainer(BaseSDTrainProcess):
     conditioned_prompts_list = [prompts_1]
     imgs_list = [imgs]
     adapter_images_list = [adapter_images]
+    clip_images_list = [clip_images]
     mask_multiplier_list = [mask_multiplier]
     if prompts_2 is None:
         prompt_2_list = [None]
     else:
         prompt_2_list = [prompts_2]
-for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier, prompt_2 in zip(
+for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, clip_images, mask_multiplier, prompt_2 in zip(
         noisy_latents_list,
         noise_list,
         timesteps_list,
         conditioned_prompts_list,
         imgs_list,
         adapter_images_list,
+        clip_images_list,
         mask_multiplier_list,
         prompt_2_list
 ):
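
Worth noting: zip() stops at the shortest iterable, so every per-sample list above must have exactly batch_size entries; that is why the absent-modality branches fill with batch_size Nones rather than an empty list. Inside the loop, clip_images is then either a [1, C, H, W] tensor or None. A runnable miniature of the pattern:

import torch

batch_size = 2
imgs_list = torch.chunk(torch.randn(batch_size, 3, 64, 64), batch_size, dim=0)
clip_images_list = [None for _ in range(batch_size)]  # modality absent this step

for imgs, clip_images in zip(imgs_list, clip_images_list):
    # clip_images arrives as None whenever the dataset had no clip_image_path
    assert imgs.shape[0] == 1 and clip_images is None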
@@ -766,7 +780,7 @@ class SDTrainer(BaseSDTrainProcess):
 if has_adapter_img and (
         (self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
     with torch.set_grad_enabled(self.adapter is not None):
-        adapter = self.adapter if self.adapter else self.assistant_adapter
+        adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
         adapter_multiplier = get_adapter_multiplier()
         with self.timer('encode_adapter'):
             down_block_additional_residuals = adapter(adapter_images)
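
The reordered ternary above makes a loaded assistant adapter take precedence even while self.adapter (the IPAdapter being trained) is set, which is the point of the commit: the IPAdapter consumes clip_images while the frozen assistant T2IAdapter consumes the control images. For orientation, a simplified sketch of how such residuals are typically handed to a diffusers UNet; unet, t2i_adapter, and the input tensors are assumed to be already loaded and are not defined by this diff:

import torch

# assumed pre-loaded: unet (UNet2DConditionModel), t2i_adapter (T2IAdapter),
# plus noisy_latents, timesteps, prompt_embeds, control_images
with torch.no_grad():  # the assistant adapter itself stays frozen
    down_block_additional_residuals = t2i_adapter(control_images)

noise_pred = unet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=prompt_embeds,
    down_block_additional_residuals=down_block_additional_residuals,
).sample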
@@ -787,20 +801,20 @@ class SDTrainer(BaseSDTrainProcess):
 if self.adapter and isinstance(self.adapter, IPAdapter):
     with self.timer('encode_adapter_embeds'):
-        if has_adapter_img:
+        if has_clip_image:
             conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
-                adapter_images.detach().to(self.device_torch, dtype=dtype),
+                clip_images.detach().to(self.device_torch, dtype=dtype),
                 is_training=True
             )
         elif is_reg:
             # we will zero it out in the img embedder
-            adapter_img = torch.zeros(
+            clip_images = torch.zeros(
                 (noisy_latents.shape[0], 3, 512, 512),
                 device=self.device_torch, dtype=dtype
             ).detach()
             # drop will zero it out
             conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
-                adapter_img,
+                clip_images,
                 drop=True,
                 is_training=True
             )
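
For reg(ularization) batches there is no real reference image, so the branch above feeds the embedder a black 512x512 placeholder and leans on drop=True to zero out the image conditioning, effectively training the image-dropped (unconditional) branch. A condensed sketch of the same call, where get_clip_image_embeds_from_tensors is this repo's IPAdapter method and the remaining names are assumed in scope:

import torch

placeholder = torch.zeros(
    (noisy_latents.shape[0], 3, 512, 512),  # one zero image per reg sample
    device=device, dtype=dtype,
).detach()

# drop=True tells the embedder to drop the image conditioning, so the content
# of the placeholder never matters; the result acts as the unconditional embedding.
uncond_clip_embeds = adapter.get_clip_image_embeds_from_tensors(
    placeholder,
    drop=True,
    is_training=True,
)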