Added support to disable single transformers in vision direct adapter

This commit is contained in:
Jaret Burkett
2024-09-11 08:54:51 -06:00
parent fc34a69bec
commit 951e223481
2 changed files with 23 additions and 12 deletions

View File

@@ -203,6 +203,8 @@ class AdapterConfig:
         self.ilora_mid: bool = kwargs.get('ilora_mid', True)
         self.ilora_up: bool = kwargs.get('ilora_up', True)
+        self.flux_only_double: bool = kwargs.get('flux_only_double', False)
+

 class EmbeddingConfig:
     def __init__(self, **kwargs):

View File
View File

@@ -616,19 +616,28 @@ class VisionDirectAdapter(torch.nn.Module):
             for i, module in transformer.transformer_blocks.named_children():
                 module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"]
-            # do single blocks too even though they dont have cross attn
-            for i, module in transformer.single_transformer_blocks.named_children():
-                module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"]
+            if not self.config.flux_only_double:
+                # do single blocks too even though they dont have cross attn
+                for i, module in transformer.single_transformer_blocks.named_children():
+                    module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"]

-            self.adapter_modules = torch.nn.ModuleList(
-                [
-                    transformer.transformer_blocks[i].attn.processor for i in
-                    range(len(transformer.transformer_blocks))
-                ] + [
-                    transformer.single_transformer_blocks[i].attn.processor for i in
-                    range(len(transformer.single_transformer_blocks))
-                ]
-            )
+            if not self.config.flux_only_double:
+                self.adapter_modules = torch.nn.ModuleList(
+                    [
+                        transformer.transformer_blocks[i].attn.processor for i in
+                        range(len(transformer.transformer_blocks))
+                    ] + [
+                        transformer.single_transformer_blocks[i].attn.processor for i in
+                        range(len(transformer.single_transformer_blocks))
+                    ]
+                )
+            else:
+                self.adapter_modules = torch.nn.ModuleList(
+                    [
+                        transformer.transformer_blocks[i].attn.processor for i in
+                        range(len(transformer.transformer_blocks))
+                    ]
+                )
         else:
             sd.unet.set_attn_processor(attn_procs)
             self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())