Many bug fixes. IP adapter bug fixes. Added noise to unconditional; it works better. Added an ilora adapter for one-shotting LoRAs

This commit is contained in:
Jaret Burkett
2024-01-28 08:20:03 -07:00
parent f17ad8d794
commit 92b9c71d44
10 changed files with 352 additions and 56 deletions

View File

@@ -286,9 +286,10 @@ class SDTrainer(BaseSDTrainProcess):
)
prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
if torch.isnan(prior_loss).any():
raise ValueError("Prior loss is nan")
prior_loss = prior_loss.mean([1, 2, 3])
print("Prior loss is nan")
prior_loss = None
else:
prior_loss = prior_loss.mean([1, 2, 3])
# loss = loss + prior_loss
loss = loss.mean([1, 2, 3])
if prior_loss is not None:
@@ -992,6 +993,8 @@ class SDTrainer(BaseSDTrainProcess):
if self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter_embeds'):
# number of images to do if doing a quad image
quad_count = random.randint(1, 4)
image_size = self.adapter.input_size
if is_reg:
# we will zero it out in the img embedder
@@ -1004,7 +1007,8 @@ class SDTrainer(BaseSDTrainProcess):
clip_images,
drop=True,
is_training=True,
has_been_preprocessed=True
has_been_preprocessed=True,
quad_count=quad_count
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
@@ -1014,13 +1018,15 @@ class SDTrainer(BaseSDTrainProcess):
).detach(),
is_training=True,
drop=True,
has_been_preprocessed=True
has_been_preprocessed=True,
quad_count=quad_count
)
elif has_clip_image:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
clip_images.detach().to(self.device_torch, dtype=dtype),
is_training=True,
has_been_preprocessed=True
has_been_preprocessed=True,
quad_count=quad_count
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
@@ -1030,7 +1036,8 @@ class SDTrainer(BaseSDTrainProcess):
).detach(),
is_training=True,
drop=True,
has_been_preprocessed=True
has_been_preprocessed=True,
quad_count=quad_count
)
else:
raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
@@ -1152,7 +1159,8 @@ class SDTrainer(BaseSDTrainProcess):
)
# check if nan
if torch.isnan(loss):
raise ValueError("loss is nan")
print("loss is nan")
loss = torch.zeros_like(loss).requires_grad_(True)
with self.timer('backward'):
# todo we have multiplier separated. works for now as res are not in same batch, but need to change