Mirror of https://github.com/ostris/ai-toolkit.git
Added support to disable single transformers in vision direct adapter
@@ -203,6 +203,8 @@ class AdapterConfig:
         self.ilora_mid: bool = kwargs.get('ilora_mid', True)
         self.ilora_up: bool = kwargs.get('ilora_up', True)
 
+        self.flux_only_double: bool = kwargs.get('flux_only_double', False)
+
 
 class EmbeddingConfig:
     def __init__(self, **kwargs):
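The new flag defaults to off, so existing configs are unaffected. Because AdapterConfig reads it straight from kwargs, it can be toggled like any other adapter option. A minimal sketch of that behavior (the toolkit.config_modules import path is an assumption based on the repo layout, not shown in this diff):

    # Sketch only: import path assumed, not confirmed by this diff.
    from toolkit.config_modules import AdapterConfig

    cfg = AdapterConfig()
    assert cfg.flux_only_double is False  # default: single blocks stay adapted

    cfg = AdapterConfig(flux_only_double=True)
    assert cfg.flux_only_double is True   # opt in to double-blocks-only wiring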
@@ -616,19 +616,28 @@ class VisionDirectAdapter(torch.nn.Module):
             for i, module in transformer.transformer_blocks.named_children():
                 module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"]
 
-            # do single blocks too even though they dont have cross attn
-            for i, module in transformer.single_transformer_blocks.named_children():
-                module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"]
+            if not self.config.flux_only_double:
+                # do single blocks too even though they dont have cross attn
+                for i, module in transformer.single_transformer_blocks.named_children():
+                    module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"]
 
-            self.adapter_modules = torch.nn.ModuleList(
-                [
-                    transformer.transformer_blocks[i].attn.processor for i in
-                    range(len(transformer.transformer_blocks))
-                ] + [
-                    transformer.single_transformer_blocks[i].attn.processor for i in
-                    range(len(transformer.single_transformer_blocks))
-                ]
-            )
+            if not self.config.flux_only_double:
+                self.adapter_modules = torch.nn.ModuleList(
+                    [
+                        transformer.transformer_blocks[i].attn.processor for i in
+                        range(len(transformer.transformer_blocks))
+                    ] + [
+                        transformer.single_transformer_blocks[i].attn.processor for i in
+                        range(len(transformer.single_transformer_blocks))
+                    ]
+                )
+            else:
+                self.adapter_modules = torch.nn.ModuleList(
+                    [
+                        transformer.transformer_blocks[i].attn.processor for i in
+                        range(len(transformer.transformer_blocks))
+                    ]
+                )
         else:
             sd.unet.set_attn_processor(attn_procs)
             self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
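To see what the flag changes, the gating can be reduced to its effect on the trainable module list. A self-contained sketch with stand-in processors instead of the real FLUX transformer (DummyProcessor and collect_adapter_modules are illustrative names, not from the repo):

    import torch

    class DummyProcessor(torch.nn.Module):
        # Stand-in for a VisionDirectAdapter attention processor.
        pass

    def collect_adapter_modules(n_double: int, n_single: int,
                                flux_only_double: bool) -> torch.nn.ModuleList:
        # Mirrors the diff: double blocks are always adapted; single
        # blocks are skipped entirely when flux_only_double is set.
        double = [DummyProcessor() for _ in range(n_double)]
        single = [] if flux_only_double else [DummyProcessor() for _ in range(n_single)]
        return torch.nn.ModuleList(double + single)

    # FLUX.1-dev has 19 double and 38 single transformer blocks.
    print(len(collect_adapter_modules(19, 38, flux_only_double=False)))  # 57
    print(len(collect_adapter_modules(19, 38, flux_only_double=True)))   # 19

Skipping the 38 single blocks cuts the adapted attention processors to a third, at the cost of not steering the single-stream layers.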