Bug fixes

This commit is contained in:
Jaret Burkett
2025-01-31 13:23:01 -07:00
parent 15a57bc89f
commit e6180d1e1d
3 changed files with 12 additions and 11 deletions

View File

@@ -123,11 +123,11 @@ class CustomAdapter(torch.nn.Module):
torch_dtype = get_torch_dtype(self.sd_ref().dtype)
if self.adapter_type == 'photo_maker':
sd = self.sd_ref()
embed_dim = sd.unet.config['cross_attention_dim']
embed_dim = sd.unet_unwrapped.config['cross_attention_dim']
self.fuse_module = FuseModule(embed_dim)
elif self.adapter_type == 'clip_fusion':
sd = self.sd_ref()
embed_dim = sd.unet.config['cross_attention_dim']
embed_dim = sd.unet_unwrapped.config['cross_attention_dim']
vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2)
if self.config.image_encoder_arch == 'clip':
@@ -288,7 +288,7 @@ class CustomAdapter(torch.nn.Module):
self.vision_encoder = SAFEVisionModel(
in_channels=3,
num_tokens=self.config.safe_tokens,
num_vectors=sd.unet.config['cross_attention_dim'],
num_vectors=sd.unet_unwrapped.config['cross_attention_dim'],
reducer_channels=self.config.safe_reducer_channels,
channels=self.config.safe_channels,
downscale_factor=8