Added support to disable single transformers in vision direct adapter

This commit is contained in:
Jaret Burkett
2024-09-11 08:54:51 -06:00
parent fc34a69bec
commit 951e223481
2 changed files with 23 additions and 12 deletions

View File

@@ -203,6 +203,8 @@ class AdapterConfig:
self.ilora_mid: bool = kwargs.get('ilora_mid', True)
self.ilora_up: bool = kwargs.get('ilora_up', True)
self.flux_only_double: bool = kwargs.get('flux_only_double', False)
class EmbeddingConfig:
def __init__(self, **kwargs):

View File

@@ -616,19 +616,28 @@ class VisionDirectAdapter(torch.nn.Module):
for i, module in transformer.transformer_blocks.named_children():
module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"]
# do single blocks too even though they dont have cross attn
for i, module in transformer.single_transformer_blocks.named_children():
module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"]
if not self.config.flux_only_double:
# do single blocks too even though they dont have cross attn
for i, module in transformer.single_transformer_blocks.named_children():
module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"]
self.adapter_modules = torch.nn.ModuleList(
[
transformer.transformer_blocks[i].attn.processor for i in
range(len(transformer.transformer_blocks))
] + [
transformer.single_transformer_blocks[i].attn.processor for i in
range(len(transformer.single_transformer_blocks))
]
)
if not self.config.flux_only_double:
self.adapter_modules = torch.nn.ModuleList(
[
transformer.transformer_blocks[i].attn.processor for i in
range(len(transformer.transformer_blocks))
] + [
transformer.single_transformer_blocks[i].attn.processor for i in
range(len(transformer.single_transformer_blocks))
]
)
else:
self.adapter_modules = torch.nn.ModuleList(
[
transformer.transformer_blocks[i].attn.processor for i in
range(len(transformer.transformer_blocks))
]
)
else:
sd.unet.set_attn_processor(attn_procs)
self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())