diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index e82842d1..732021fe 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -163,7 +163,7 @@ class SDTrainer(BaseSDTrainProcess):
             loss = loss * mask_multiplier
 
         prior_loss = None
-        if self.train_config.inverted_mask_prior and prior_pred is not None:
+        if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
             # to a loss to unmasked areas of the prior for unmasked regularization
             prior_loss = torch.nn.functional.mse_loss(
                 prior_pred.float(),
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index 2f49fa21..efa9ac07 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -184,8 +184,8 @@ class IPAdapter(torch.nn.Module):
             self.clip_image_processor = SAFEImageProcessor()
             self.image_encoder = SAFEVisionModel(
                 in_channels=3,
-                num_tokens=self.config.num_tokens if self.config.adapter_type == 'ip+' else 1,
-                num_vectors=sd.unet.config['cross_attention_dim'] if self.config.adapter_type == 'ip+' else self.config.safe_channels,
+                num_tokens=8,
+                num_vectors=sd.unet.config['cross_attention_dim'],
                 reducer_channels=self.config.safe_reducer_channels,
                 channels=self.config.safe_channels,
                 downscale_factor=8
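
Note on the SDTrainer hunk: the added check only computes the inverted-mask prior loss when prior_mask_multiplier exists, instead of proceeding with a missing multiplier. A minimal sketch of that guard pattern follows; the helper name and simplified arguments are placeholders for illustration, not the trainer's actual signature, and the real code computes the loss against its own target and reduction settings.

    # Sketch only: illustrates the None-guard added in the diff above.
    # compute_prior_loss and its arguments are hypothetical names.
    import torch

    def compute_prior_loss(prior_pred, target, prior_mask_multiplier, inverted_mask_prior):
        # Skip entirely unless the prior prediction AND its inverted-mask multiplier exist.
        if not (inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None):
            return None
        # Per-element MSE so the mask multiplier can weight unmasked regions before reducing.
        loss = torch.nn.functional.mse_loss(prior_pred.float(), target.float(), reduction="none")
        return (loss * prior_mask_multiplier).mean()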