mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
WIP Ilora
This commit is contained in:
@@ -145,7 +145,8 @@ class CustomAdapter(torch.nn.Module):
|
||||
self.ilora_module = InstantLoRAModule(
|
||||
vision_tokens=vision_tokens,
|
||||
vision_hidden_size=vision_hidden_size,
|
||||
head_dim=1024,
|
||||
head_dim=self.config.head_dim,
|
||||
num_heads=self.config.num_heads,
|
||||
sd=self.sd_ref()
|
||||
)
|
||||
elif self.adapter_type == 'text_encoder':
|
||||
@@ -878,6 +879,11 @@ class CustomAdapter(torch.nn.Module):
|
||||
self.vision_encoder.gradient_checkpointing = True
|
||||
|
||||
def get_additional_save_metadata(self) -> Dict[str, Any]:
|
||||
additional = {}
|
||||
if self.config.type == 'ilora':
|
||||
return self.ilora_module.get_additional_save_metadata()
|
||||
return {}
|
||||
extra = self.ilora_module.get_additional_save_metadata()
|
||||
for k, v in extra.items():
|
||||
additional[k] = v
|
||||
additional['clip_layer'] = self.config.clip_layer
|
||||
additional['image_encoder_arch'] = self.config.head_dim
|
||||
return additional
|
||||
Reference in New Issue
Block a user