mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 02:01:29 +00:00
Bug fixes
This commit is contained in:
@@ -123,11 +123,11 @@ class CustomAdapter(torch.nn.Module):
|
||||
torch_dtype = get_torch_dtype(self.sd_ref().dtype)
|
||||
if self.adapter_type == 'photo_maker':
|
||||
sd = self.sd_ref()
|
||||
embed_dim = sd.unet.config['cross_attention_dim']
|
||||
embed_dim = sd.unet_unwrapped.config['cross_attention_dim']
|
||||
self.fuse_module = FuseModule(embed_dim)
|
||||
elif self.adapter_type == 'clip_fusion':
|
||||
sd = self.sd_ref()
|
||||
embed_dim = sd.unet.config['cross_attention_dim']
|
||||
embed_dim = sd.unet_unwrapped.config['cross_attention_dim']
|
||||
|
||||
vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2)
|
||||
if self.config.image_encoder_arch == 'clip':
|
||||
@@ -288,7 +288,7 @@ class CustomAdapter(torch.nn.Module):
|
||||
self.vision_encoder = SAFEVisionModel(
|
||||
in_channels=3,
|
||||
num_tokens=self.config.safe_tokens,
|
||||
num_vectors=sd.unet.config['cross_attention_dim'],
|
||||
num_vectors=sd.unet_unwrapped.config['cross_attention_dim'],
|
||||
reducer_channels=self.config.safe_reducer_channels,
|
||||
channels=self.config.safe_channels,
|
||||
downscale_factor=8
|
||||
|
||||
Reference in New Issue
Block a user