Fixed some new bugs i added. woops

This commit is contained in:
Jaret Burkett
2023-12-28 14:03:42 -07:00
parent eeee4a1620
commit 0892dec4a5
2 changed files with 3 additions and 3 deletions

View File

@@ -163,7 +163,7 @@ class SDTrainer(BaseSDTrainProcess):
loss = loss * mask_multiplier
prior_loss = None
if self.train_config.inverted_mask_prior and prior_pred is not None:
if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
# to a loss to unmasked areas of the prior for unmasked regularization
prior_loss = torch.nn.functional.mse_loss(
prior_pred.float(),

View File

@@ -184,8 +184,8 @@ class IPAdapter(torch.nn.Module):
self.clip_image_processor = SAFEImageProcessor()
self.image_encoder = SAFEVisionModel(
in_channels=3,
num_tokens=self.config.num_tokens if self.config.adapter_type == 'ip+' else 1,
num_vectors=sd.unet.config['cross_attention_dim'] if self.config.adapter_type == 'ip+' else self.config.safe_channels,
num_tokens=8,
num_vectors=sd.unet.config['cross_attention_dim'],
reducer_channels=self.config.safe_reducer_channels,
channels=self.config.safe_channels,
downscale_factor=8