mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Bug fixes and minor features
This commit is contained in:
@@ -487,7 +487,7 @@ class IPAdapter(torch.nn.Module):
|
||||
attn_processor_names = []
|
||||
|
||||
for name in attn_processor_keys:
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else \
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
|
||||
sd.unet.config['cross_attention_dim']
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = sd.unet.config['block_out_channels'][-1]
|
||||
@@ -540,9 +540,6 @@ class IPAdapter(torch.nn.Module):
|
||||
module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]
|
||||
self.adapter_modules = torch.nn.ModuleList(
|
||||
[
|
||||
transformer.transformer_blocks[i].attn1.processor for i in
|
||||
range(len(transformer.transformer_blocks))
|
||||
] + [
|
||||
transformer.transformer_blocks[i].attn2.processor for i in
|
||||
range(len(transformer.transformer_blocks))
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user