diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 4950868c..b8c82c8e 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -203,6 +203,8 @@ class AdapterConfig:
         self.ilora_mid: bool = kwargs.get('ilora_mid', True)
         self.ilora_up: bool = kwargs.get('ilora_up', True)
 
+        self.flux_only_double: bool = kwargs.get('flux_only_double', False)
+
 
 class EmbeddingConfig:
     def __init__(self, **kwargs):
diff --git a/toolkit/models/vd_adapter.py b/toolkit/models/vd_adapter.py
index 11612134..bdbd108a 100644
--- a/toolkit/models/vd_adapter.py
+++ b/toolkit/models/vd_adapter.py
@@ -616,19 +616,28 @@ class VisionDirectAdapter(torch.nn.Module):
             for i, module in transformer.transformer_blocks.named_children():
                 module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"]
 
-            # do single blocks too even though they dont have cross attn
-            for i, module in transformer.single_transformer_blocks.named_children():
-                module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"]
+            if not self.config.flux_only_double:
+                # do single blocks too even though they dont have cross attn
+                for i, module in transformer.single_transformer_blocks.named_children():
+                    module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"]
 
-            self.adapter_modules = torch.nn.ModuleList(
-                [
-                    transformer.transformer_blocks[i].attn.processor for i in
-                    range(len(transformer.transformer_blocks))
-                ] + [
-                    transformer.single_transformer_blocks[i].attn.processor for i in
-                    range(len(transformer.single_transformer_blocks))
-                ]
-            )
+            if not self.config.flux_only_double:
+                self.adapter_modules = torch.nn.ModuleList(
+                    [
+                        transformer.transformer_blocks[i].attn.processor for i in
+                        range(len(transformer.transformer_blocks))
+                    ] + [
+                        transformer.single_transformer_blocks[i].attn.processor for i in
+                        range(len(transformer.single_transformer_blocks))
+                    ]
+                )
+            else:
+                self.adapter_modules = torch.nn.ModuleList(
+                    [
+                        transformer.transformer_blocks[i].attn.processor for i in
+                        range(len(transformer.transformer_blocks))
+                    ]
+                )
         else:
             sd.unet.set_attn_processor(attn_procs)
             self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
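
Note: a minimal usage sketch of the new flag, assuming only the AdapterConfig constructor shown in the patch above; it is not taken from the repository's own examples:

    from toolkit.config_modules import AdapterConfig

    # Restrict the Flux VisionDirectAdapter to the double (joint) transformer
    # blocks. With this flag set, single_transformer_blocks keep their default
    # attention processors and are excluded from adapter_modules.
    adapter_config = AdapterConfig(flux_only_double=True)

    print(adapter_config.flux_only_double)  # True; the default is False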