Small tweaks and fixes for specialized IP-Adapter training

This commit is contained in:
Jaret Burkett
2024-03-26 11:35:26 -06:00
parent 9c1cc9641e
commit 427847ac4c
6 changed files with 117 additions and 11 deletions

View File

@@ -1211,7 +1211,7 @@ class SDTrainer(BaseSDTrainProcess):
clip_images,
drop=True,
is_training=True,
has_been_preprocessed=True,
has_been_preprocessed=False,
quad_count=quad_count
)
if self.train_config.do_cfg:
@@ -1222,7 +1222,7 @@ class SDTrainer(BaseSDTrainProcess):
).detach(),
is_training=True,
drop=True,
has_been_preprocessed=True,
has_been_preprocessed=False,
quad_count=quad_count
)
elif has_clip_image:
@@ -1230,14 +1230,14 @@ class SDTrainer(BaseSDTrainProcess):
clip_images.detach().to(self.device_torch, dtype=dtype),
is_training=True,
has_been_preprocessed=True,
quad_count=quad_count
quad_count=quad_count,
# do cfg on clip embeds to normalize the embeddings for when doing cfg
# cfg_embed_strength=3.0 if not self.train_config.do_cfg else None
# cfg_embed_strength=3.0 if not self.train_config.do_cfg else None
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
torch.zeros(
(noisy_latents.shape[0], 3, image_size, image_size),
device=self.device_torch, dtype=dtype
).detach(),
clip_images.detach().to(self.device_torch, dtype=dtype),
is_training=True,
drop=True,
has_been_preprocessed=True,