Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-03-13 14:39:50 +00:00
Added support to disable single transformers in vision direct adapter
@@ -203,6 +203,8 @@ class AdapterConfig:
         self.ilora_mid: bool = kwargs.get('ilora_mid', True)
         self.ilora_up: bool = kwargs.get('ilora_up', True)
 
+        self.flux_only_double: bool = kwargs.get('flux_only_double', False)
+
 
 class EmbeddingConfig:
     def __init__(self, **kwargs):
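Because the flag is read with kwargs.get and a False default, configs that never mention it behave exactly as before. A trimmed, runnable sketch of that pattern (only the lines shown in the hunk above; the rest of the real AdapterConfig is omitted):

# Trimmed sketch of the kwargs.get pattern from the hunk above;
# everything else in the real AdapterConfig is omitted.
class AdapterConfig:
    def __init__(self, **kwargs):
        self.ilora_mid: bool = kwargs.get('ilora_mid', True)
        self.ilora_up: bool = kwargs.get('ilora_up', True)
        # New flag: defaults to False, so existing configs are unaffected.
        self.flux_only_double: bool = kwargs.get('flux_only_double', False)

print(AdapterConfig().flux_only_double)                       # False
print(AdapterConfig(flux_only_double=True).flux_only_double)  # True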
@@ -616,19 +616,28 @@ class VisionDirectAdapter(torch.nn.Module):
             for i, module in transformer.transformer_blocks.named_children():
                 module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"]
 
-            # do single blocks too even though they dont have cross attn
-            for i, module in transformer.single_transformer_blocks.named_children():
-                module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"]
+            if not self.config.flux_only_double:
+                # do single blocks too even though they dont have cross attn
+                for i, module in transformer.single_transformer_blocks.named_children():
+                    module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"]
 
-            self.adapter_modules = torch.nn.ModuleList(
-                [
-                    transformer.transformer_blocks[i].attn.processor for i in
-                    range(len(transformer.transformer_blocks))
-                ] + [
-                    transformer.single_transformer_blocks[i].attn.processor for i in
-                    range(len(transformer.single_transformer_blocks))
-                ]
-            )
+            if not self.config.flux_only_double:
+                self.adapter_modules = torch.nn.ModuleList(
+                    [
+                        transformer.transformer_blocks[i].attn.processor for i in
+                        range(len(transformer.transformer_blocks))
+                    ] + [
+                        transformer.single_transformer_blocks[i].attn.processor for i in
+                        range(len(transformer.single_transformer_blocks))
+                    ]
+                )
+            else:
+                self.adapter_modules = torch.nn.ModuleList(
+                    [
+                        transformer.transformer_blocks[i].attn.processor for i in
+                        range(len(transformer.transformer_blocks))
+                    ]
+                )
         else:
             sd.unet.set_attn_processor(attn_procs)
             self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
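The effect of the new branch: when flux_only_double is set, the single transformer blocks keep their stock attention processors and are left out of adapter_modules, so they contribute no trainable adapter parameters. A toy sketch of the conditional collection (stand-in Identity modules instead of real attention processors; the function name is illustrative, and the 19 double / 38 single block counts assume a FLUX.1-style transformer):

# Toy sketch of the conditional ModuleList assembly above. Identity modules
# stand in for attention processors; names here are illustrative only.
import torch

def collect_adapter_modules(double_procs, single_procs, flux_only_double=False):
    procs = list(double_procs)
    if not flux_only_double:
        # Single blocks are adapted only when the flag is off (the default).
        procs += list(single_procs)
    return torch.nn.ModuleList(procs)

double = [torch.nn.Identity() for _ in range(19)]  # assumed double block count
single = [torch.nn.Identity() for _ in range(38)]  # assumed single block count

print(len(collect_adapter_modules(double, single)))                         # 57
print(len(collect_adapter_modules(double, single, flux_only_double=True)))  # 19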