mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Small tweaks and fixes for specialized ip adapter training
This commit is contained in:
@@ -1211,7 +1211,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
clip_images,
|
||||
drop=True,
|
||||
is_training=True,
|
||||
has_been_preprocessed=True,
|
||||
has_been_preprocessed=False,
|
||||
quad_count=quad_count
|
||||
)
|
||||
if self.train_config.do_cfg:
|
||||
@@ -1222,7 +1222,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
).detach(),
|
||||
is_training=True,
|
||||
drop=True,
|
||||
has_been_preprocessed=True,
|
||||
has_been_preprocessed=False,
|
||||
quad_count=quad_count
|
||||
)
|
||||
elif has_clip_image:
|
||||
@@ -1230,14 +1230,14 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
clip_images.detach().to(self.device_torch, dtype=dtype),
|
||||
is_training=True,
|
||||
has_been_preprocessed=True,
|
||||
quad_count=quad_count
|
||||
quad_count=quad_count,
|
||||
# do cfg on clip embeds to normalize the embeddings for when doing cfg
|
||||
# cfg_embed_strength=3.0 if not self.train_config.do_cfg else None
|
||||
# cfg_embed_strength=3.0 if not self.train_config.do_cfg else None
|
||||
)
|
||||
if self.train_config.do_cfg:
|
||||
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
|
||||
torch.zeros(
|
||||
(noisy_latents.shape[0], 3, image_size, image_size),
|
||||
device=self.device_torch, dtype=dtype
|
||||
).detach(),
|
||||
clip_images.detach().to(self.device_torch, dtype=dtype),
|
||||
is_training=True,
|
||||
drop=True,
|
||||
has_been_preprocessed=True,
|
||||
|
||||
Reference in New Issue
Block a user