mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fixed some new bugs i added. woops
This commit is contained in:
@@ -163,7 +163,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
loss = loss * mask_multiplier
|
||||
|
||||
prior_loss = None
|
||||
if self.train_config.inverted_mask_prior and prior_pred is not None:
|
||||
if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
|
||||
# to a loss to unmasked areas of the prior for unmasked regularization
|
||||
prior_loss = torch.nn.functional.mse_loss(
|
||||
prior_pred.float(),
|
||||
|
||||
@@ -184,8 +184,8 @@ class IPAdapter(torch.nn.Module):
|
||||
self.clip_image_processor = SAFEImageProcessor()
|
||||
self.image_encoder = SAFEVisionModel(
|
||||
in_channels=3,
|
||||
num_tokens=self.config.num_tokens if self.config.adapter_type == 'ip+' else 1,
|
||||
num_vectors=sd.unet.config['cross_attention_dim'] if self.config.adapter_type == 'ip+' else self.config.safe_channels,
|
||||
num_tokens=8,
|
||||
num_vectors=sd.unet.config['cross_attention_dim'],
|
||||
reducer_channels=self.config.safe_reducer_channels,
|
||||
channels=self.config.safe_channels,
|
||||
downscale_factor=8
|
||||
|
||||
Reference in New Issue
Block a user